Spaces:
Runtime error
Runtime error
Upload 41 files
Browse files- modules/artists.py +25 -0
- modules/bsrgan_model.py +78 -0
- modules/bsrgan_model_arch.py +102 -0
- modules/codeformer_model.py +140 -0
- modules/devices.py +68 -0
- modules/errors.py +10 -0
- modules/esrgam_model_arch.py +80 -0
- modules/esrgan_model.py +160 -0
- modules/extras.py +208 -0
- modules/face_restoration.py +19 -0
- modules/generation_parameters_copypaste.py +92 -0
- modules/gfpgan_model.py +115 -0
- modules/hypernetwork.py +88 -0
- modules/images.py +430 -0
- modules/img2img.py +135 -0
- modules/interrogate.py +168 -0
- modules/ldsr_model.py +56 -0
- modules/ldsr_model_arch.py +222 -0
- modules/lowvram.py +82 -0
- modules/masking.py +99 -0
- modules/memmon.py +85 -0
- modules/modelloader.py +153 -0
- modules/paths.py +38 -0
- modules/processing.py +688 -0
- modules/prompt_parser.py +352 -0
- modules/realesrgan_model.py +135 -0
- modules/safety.py +42 -0
- modules/scripts.py +201 -0
- modules/scunet_model.py +90 -0
- modules/scunet_model_arch.py +265 -0
- modules/sd_hijack.py +321 -0
- modules/sd_hijack_optimizations.py +169 -0
- modules/sd_models.py +192 -0
- modules/sd_samplers.py +380 -0
- modules/shared.py +366 -0
- modules/styles.py +92 -0
- modules/swinir_model.py +142 -0
- modules/swinir_model_arch.py +867 -0
- modules/txt2img.py +55 -0
- modules/ui.py +1414 -0
- modules/upscaler.py +121 -0
modules/artists.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
import csv
|
3 |
+
from collections import namedtuple
|
4 |
+
|
5 |
+
Artist = namedtuple("Artist", ['name', 'weight', 'category'])
|
6 |
+
|
7 |
+
|
8 |
+
class ArtistsDatabase:
|
9 |
+
def __init__(self, filename):
|
10 |
+
self.cats = set()
|
11 |
+
self.artists = []
|
12 |
+
|
13 |
+
if not os.path.exists(filename):
|
14 |
+
return
|
15 |
+
|
16 |
+
with open(filename, "r", newline='', encoding="utf8") as file:
|
17 |
+
reader = csv.DictReader(file)
|
18 |
+
|
19 |
+
for row in reader:
|
20 |
+
artist = Artist(row["artist"], float(row["score"]), row["category"])
|
21 |
+
self.artists.append(artist)
|
22 |
+
self.cats.add(artist.category)
|
23 |
+
|
24 |
+
def categories(self):
|
25 |
+
return sorted(self.cats)
|
modules/bsrgan_model.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
import sys
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
import PIL.Image
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from basicsr.utils.download_util import load_file_from_url
|
9 |
+
|
10 |
+
import modules.upscaler
|
11 |
+
from modules import devices, modelloader
|
12 |
+
from modules.bsrgan_model_arch import RRDBNet
|
13 |
+
from modules.paths import models_path
|
14 |
+
|
15 |
+
|
16 |
+
class UpscalerBSRGAN(modules.upscaler.Upscaler):
|
17 |
+
def __init__(self, dirname):
|
18 |
+
self.name = "BSRGAN"
|
19 |
+
self.model_path = os.path.join(models_path, self.name)
|
20 |
+
self.model_name = "BSRGAN 4x"
|
21 |
+
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
|
22 |
+
self.user_path = dirname
|
23 |
+
super().__init__()
|
24 |
+
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
25 |
+
scalers = []
|
26 |
+
if len(model_paths) == 0:
|
27 |
+
scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
|
28 |
+
scalers.append(scaler_data)
|
29 |
+
for file in model_paths:
|
30 |
+
if "http" in file:
|
31 |
+
name = self.model_name
|
32 |
+
else:
|
33 |
+
name = modelloader.friendly_name(file)
|
34 |
+
try:
|
35 |
+
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
36 |
+
scalers.append(scaler_data)
|
37 |
+
except Exception:
|
38 |
+
print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
|
39 |
+
print(traceback.format_exc(), file=sys.stderr)
|
40 |
+
self.scalers = scalers
|
41 |
+
|
42 |
+
def do_upscale(self, img: PIL.Image, selected_file):
|
43 |
+
torch.cuda.empty_cache()
|
44 |
+
model = self.load_model(selected_file)
|
45 |
+
if model is None:
|
46 |
+
return img
|
47 |
+
model.to(devices.device_bsrgan)
|
48 |
+
torch.cuda.empty_cache()
|
49 |
+
img = np.array(img)
|
50 |
+
img = img[:, :, ::-1]
|
51 |
+
img = np.moveaxis(img, 2, 0) / 255
|
52 |
+
img = torch.from_numpy(img).float()
|
53 |
+
img = img.unsqueeze(0).to(devices.device_bsrgan)
|
54 |
+
with torch.no_grad():
|
55 |
+
output = model(img)
|
56 |
+
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
57 |
+
output = 255. * np.moveaxis(output, 0, 2)
|
58 |
+
output = output.astype(np.uint8)
|
59 |
+
output = output[:, :, ::-1]
|
60 |
+
torch.cuda.empty_cache()
|
61 |
+
return PIL.Image.fromarray(output, 'RGB')
|
62 |
+
|
63 |
+
def load_model(self, path: str):
|
64 |
+
if "http" in path:
|
65 |
+
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
66 |
+
progress=True)
|
67 |
+
else:
|
68 |
+
filename = path
|
69 |
+
if not os.path.exists(filename) or filename is None:
|
70 |
+
print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
|
71 |
+
return None
|
72 |
+
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
|
73 |
+
model.load_state_dict(torch.load(filename), strict=True)
|
74 |
+
model.eval()
|
75 |
+
for k, v in model.named_parameters():
|
76 |
+
v.requires_grad = False
|
77 |
+
return model
|
78 |
+
|
modules/bsrgan_model_arch.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch.nn.init as init
|
6 |
+
|
7 |
+
|
8 |
+
def initialize_weights(net_l, scale=1):
|
9 |
+
if not isinstance(net_l, list):
|
10 |
+
net_l = [net_l]
|
11 |
+
for net in net_l:
|
12 |
+
for m in net.modules():
|
13 |
+
if isinstance(m, nn.Conv2d):
|
14 |
+
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
15 |
+
m.weight.data *= scale # for residual block
|
16 |
+
if m.bias is not None:
|
17 |
+
m.bias.data.zero_()
|
18 |
+
elif isinstance(m, nn.Linear):
|
19 |
+
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
20 |
+
m.weight.data *= scale
|
21 |
+
if m.bias is not None:
|
22 |
+
m.bias.data.zero_()
|
23 |
+
elif isinstance(m, nn.BatchNorm2d):
|
24 |
+
init.constant_(m.weight, 1)
|
25 |
+
init.constant_(m.bias.data, 0.0)
|
26 |
+
|
27 |
+
|
28 |
+
def make_layer(block, n_layers):
|
29 |
+
layers = []
|
30 |
+
for _ in range(n_layers):
|
31 |
+
layers.append(block())
|
32 |
+
return nn.Sequential(*layers)
|
33 |
+
|
34 |
+
|
35 |
+
class ResidualDenseBlock_5C(nn.Module):
|
36 |
+
def __init__(self, nf=64, gc=32, bias=True):
|
37 |
+
super(ResidualDenseBlock_5C, self).__init__()
|
38 |
+
# gc: growth channel, i.e. intermediate channels
|
39 |
+
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
40 |
+
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
41 |
+
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
42 |
+
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
43 |
+
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
44 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
45 |
+
|
46 |
+
# initialization
|
47 |
+
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
x1 = self.lrelu(self.conv1(x))
|
51 |
+
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
52 |
+
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
53 |
+
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
54 |
+
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
55 |
+
return x5 * 0.2 + x
|
56 |
+
|
57 |
+
|
58 |
+
class RRDB(nn.Module):
|
59 |
+
'''Residual in Residual Dense Block'''
|
60 |
+
|
61 |
+
def __init__(self, nf, gc=32):
|
62 |
+
super(RRDB, self).__init__()
|
63 |
+
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
64 |
+
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
65 |
+
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
out = self.RDB1(x)
|
69 |
+
out = self.RDB2(out)
|
70 |
+
out = self.RDB3(out)
|
71 |
+
return out * 0.2 + x
|
72 |
+
|
73 |
+
|
74 |
+
class RRDBNet(nn.Module):
|
75 |
+
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
|
76 |
+
super(RRDBNet, self).__init__()
|
77 |
+
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
78 |
+
self.sf = sf
|
79 |
+
|
80 |
+
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
81 |
+
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
82 |
+
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
83 |
+
#### upsampling
|
84 |
+
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
85 |
+
if self.sf==4:
|
86 |
+
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
87 |
+
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
88 |
+
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
89 |
+
|
90 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
91 |
+
|
92 |
+
def forward(self, x):
|
93 |
+
fea = self.conv_first(x)
|
94 |
+
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
95 |
+
fea = fea + trunk
|
96 |
+
|
97 |
+
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
98 |
+
if self.sf==4:
|
99 |
+
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
100 |
+
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
101 |
+
|
102 |
+
return out
|
modules/codeformer_model.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import torch
|
7 |
+
|
8 |
+
import modules.face_restoration
|
9 |
+
import modules.shared
|
10 |
+
from modules import shared, devices, modelloader
|
11 |
+
from modules.paths import script_path, models_path
|
12 |
+
|
13 |
+
# codeformer people made a choice to include modified basicsr library to their project which makes
|
14 |
+
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
|
15 |
+
# I am making a choice to include some files from codeformer to work around this issue.
|
16 |
+
model_dir = "Codeformer"
|
17 |
+
model_path = os.path.join(models_path, model_dir)
|
18 |
+
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
19 |
+
|
20 |
+
have_codeformer = False
|
21 |
+
codeformer = None
|
22 |
+
|
23 |
+
|
24 |
+
def setup_model(dirname):
|
25 |
+
global model_path
|
26 |
+
if not os.path.exists(model_path):
|
27 |
+
os.makedirs(model_path)
|
28 |
+
|
29 |
+
path = modules.paths.paths.get("CodeFormer", None)
|
30 |
+
if path is None:
|
31 |
+
return
|
32 |
+
|
33 |
+
try:
|
34 |
+
from torchvision.transforms.functional import normalize
|
35 |
+
from modules.codeformer.codeformer_arch import CodeFormer
|
36 |
+
from basicsr.utils.download_util import load_file_from_url
|
37 |
+
from basicsr.utils import imwrite, img2tensor, tensor2img
|
38 |
+
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
39 |
+
from modules.shared import cmd_opts
|
40 |
+
|
41 |
+
net_class = CodeFormer
|
42 |
+
|
43 |
+
class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
|
44 |
+
def name(self):
|
45 |
+
return "CodeFormer"
|
46 |
+
|
47 |
+
def __init__(self, dirname):
|
48 |
+
self.net = None
|
49 |
+
self.face_helper = None
|
50 |
+
self.cmd_dir = dirname
|
51 |
+
|
52 |
+
def create_models(self):
|
53 |
+
|
54 |
+
if self.net is not None and self.face_helper is not None:
|
55 |
+
self.net.to(devices.device_codeformer)
|
56 |
+
return self.net, self.face_helper
|
57 |
+
model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
|
58 |
+
if len(model_paths) != 0:
|
59 |
+
ckpt_path = model_paths[0]
|
60 |
+
else:
|
61 |
+
print("Unable to load codeformer model.")
|
62 |
+
return None, None
|
63 |
+
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
|
64 |
+
checkpoint = torch.load(ckpt_path)['params_ema']
|
65 |
+
net.load_state_dict(checkpoint)
|
66 |
+
net.eval()
|
67 |
+
|
68 |
+
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
|
69 |
+
|
70 |
+
self.net = net
|
71 |
+
self.face_helper = face_helper
|
72 |
+
|
73 |
+
return net, face_helper
|
74 |
+
|
75 |
+
def send_model_to(self, device):
|
76 |
+
self.net.to(device)
|
77 |
+
self.face_helper.face_det.to(device)
|
78 |
+
self.face_helper.face_parse.to(device)
|
79 |
+
|
80 |
+
def restore(self, np_image, w=None):
|
81 |
+
np_image = np_image[:, :, ::-1]
|
82 |
+
|
83 |
+
original_resolution = np_image.shape[0:2]
|
84 |
+
|
85 |
+
self.create_models()
|
86 |
+
if self.net is None or self.face_helper is None:
|
87 |
+
return np_image
|
88 |
+
|
89 |
+
self.send_model_to(devices.device_codeformer)
|
90 |
+
|
91 |
+
self.face_helper.clean_all()
|
92 |
+
self.face_helper.read_image(np_image)
|
93 |
+
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
94 |
+
self.face_helper.align_warp_face()
|
95 |
+
|
96 |
+
for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
|
97 |
+
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
98 |
+
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
99 |
+
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
100 |
+
|
101 |
+
try:
|
102 |
+
with torch.no_grad():
|
103 |
+
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
|
104 |
+
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
105 |
+
del output
|
106 |
+
torch.cuda.empty_cache()
|
107 |
+
except Exception as error:
|
108 |
+
print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
|
109 |
+
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
110 |
+
|
111 |
+
restored_face = restored_face.astype('uint8')
|
112 |
+
self.face_helper.add_restored_face(restored_face)
|
113 |
+
|
114 |
+
self.face_helper.get_inverse_affine(None)
|
115 |
+
|
116 |
+
restored_img = self.face_helper.paste_faces_to_input_image()
|
117 |
+
restored_img = restored_img[:, :, ::-1]
|
118 |
+
|
119 |
+
if original_resolution != restored_img.shape[0:2]:
|
120 |
+
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
|
121 |
+
|
122 |
+
self.face_helper.clean_all()
|
123 |
+
|
124 |
+
if shared.opts.face_restoration_unload:
|
125 |
+
self.send_model_to(devices.cpu)
|
126 |
+
|
127 |
+
return restored_img
|
128 |
+
|
129 |
+
global have_codeformer
|
130 |
+
have_codeformer = True
|
131 |
+
|
132 |
+
global codeformer
|
133 |
+
codeformer = FaceRestorerCodeFormer(dirname)
|
134 |
+
shared.face_restorers.append(codeformer)
|
135 |
+
|
136 |
+
except Exception:
|
137 |
+
print("Error setting up CodeFormer:", file=sys.stderr)
|
138 |
+
print(traceback.format_exc(), file=sys.stderr)
|
139 |
+
|
140 |
+
# sys.path = stored_sys_path
|
modules/devices.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from modules import errors
|
6 |
+
|
7 |
+
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
8 |
+
has_mps = getattr(torch, 'has_mps', False)
|
9 |
+
|
10 |
+
cpu = torch.device("cpu")
|
11 |
+
|
12 |
+
|
13 |
+
def get_optimal_device():
|
14 |
+
if torch.cuda.is_available():
|
15 |
+
return torch.device("cuda")
|
16 |
+
|
17 |
+
if has_mps:
|
18 |
+
return torch.device("mps")
|
19 |
+
|
20 |
+
return cpu
|
21 |
+
|
22 |
+
|
23 |
+
def torch_gc():
|
24 |
+
if torch.cuda.is_available():
|
25 |
+
torch.cuda.empty_cache()
|
26 |
+
torch.cuda.ipc_collect()
|
27 |
+
|
28 |
+
|
29 |
+
def enable_tf32():
|
30 |
+
if torch.cuda.is_available():
|
31 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
32 |
+
torch.backends.cudnn.allow_tf32 = True
|
33 |
+
|
34 |
+
|
35 |
+
errors.run(enable_tf32, "Enabling TF32")
|
36 |
+
|
37 |
+
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
|
38 |
+
dtype = torch.float16
|
39 |
+
|
40 |
+
def randn(seed, shape):
|
41 |
+
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
42 |
+
if device.type == 'mps':
|
43 |
+
generator = torch.Generator(device=cpu)
|
44 |
+
generator.manual_seed(seed)
|
45 |
+
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
|
46 |
+
return noise
|
47 |
+
|
48 |
+
torch.manual_seed(seed)
|
49 |
+
return torch.randn(shape, device=device)
|
50 |
+
|
51 |
+
|
52 |
+
def randn_without_seed(shape):
|
53 |
+
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
54 |
+
if device.type == 'mps':
|
55 |
+
generator = torch.Generator(device=cpu)
|
56 |
+
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
|
57 |
+
return noise
|
58 |
+
|
59 |
+
return torch.randn(shape, device=device)
|
60 |
+
|
61 |
+
|
62 |
+
def autocast():
|
63 |
+
from modules import shared
|
64 |
+
|
65 |
+
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
66 |
+
return contextlib.nullcontext()
|
67 |
+
|
68 |
+
return torch.autocast("cuda")
|
modules/errors.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import traceback
|
3 |
+
|
4 |
+
|
5 |
+
def run(code, task):
|
6 |
+
try:
|
7 |
+
code()
|
8 |
+
except Exception as e:
|
9 |
+
print(f"{task}: {type(e).__name__}", file=sys.stderr)
|
10 |
+
print(traceback.format_exc(), file=sys.stderr)
|
modules/esrgam_model_arch.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# this file is taken from https://github.com/xinntao/ESRGAN
|
2 |
+
|
3 |
+
import functools
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
def make_layer(block, n_layers):
|
10 |
+
layers = []
|
11 |
+
for _ in range(n_layers):
|
12 |
+
layers.append(block())
|
13 |
+
return nn.Sequential(*layers)
|
14 |
+
|
15 |
+
|
16 |
+
class ResidualDenseBlock_5C(nn.Module):
|
17 |
+
def __init__(self, nf=64, gc=32, bias=True):
|
18 |
+
super(ResidualDenseBlock_5C, self).__init__()
|
19 |
+
# gc: growth channel, i.e. intermediate channels
|
20 |
+
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
21 |
+
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
22 |
+
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
23 |
+
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
24 |
+
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
25 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
26 |
+
|
27 |
+
# initialization
|
28 |
+
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
x1 = self.lrelu(self.conv1(x))
|
32 |
+
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
33 |
+
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
34 |
+
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
35 |
+
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
36 |
+
return x5 * 0.2 + x
|
37 |
+
|
38 |
+
|
39 |
+
class RRDB(nn.Module):
|
40 |
+
'''Residual in Residual Dense Block'''
|
41 |
+
|
42 |
+
def __init__(self, nf, gc=32):
|
43 |
+
super(RRDB, self).__init__()
|
44 |
+
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
45 |
+
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
46 |
+
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
out = self.RDB1(x)
|
50 |
+
out = self.RDB2(out)
|
51 |
+
out = self.RDB3(out)
|
52 |
+
return out * 0.2 + x
|
53 |
+
|
54 |
+
|
55 |
+
class RRDBNet(nn.Module):
|
56 |
+
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
|
57 |
+
super(RRDBNet, self).__init__()
|
58 |
+
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
59 |
+
|
60 |
+
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
61 |
+
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
62 |
+
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
63 |
+
#### upsampling
|
64 |
+
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
65 |
+
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
66 |
+
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
67 |
+
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
68 |
+
|
69 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
fea = self.conv_first(x)
|
73 |
+
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
74 |
+
fea = fea + trunk
|
75 |
+
|
76 |
+
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
77 |
+
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
78 |
+
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
79 |
+
|
80 |
+
return out
|
modules/esrgan_model.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
from basicsr.utils.download_util import load_file_from_url
|
7 |
+
|
8 |
+
import modules.esrgam_model_arch as arch
|
9 |
+
from modules import shared, modelloader, images, devices
|
10 |
+
from modules.paths import models_path
|
11 |
+
from modules.upscaler import Upscaler, UpscalerData
|
12 |
+
from modules.shared import opts
|
13 |
+
|
14 |
+
|
15 |
+
def fix_model_layers(crt_model, pretrained_net):
|
16 |
+
# this code is adapted from https://github.com/xinntao/ESRGAN
|
17 |
+
if 'conv_first.weight' in pretrained_net:
|
18 |
+
return pretrained_net
|
19 |
+
|
20 |
+
if 'model.0.weight' not in pretrained_net:
|
21 |
+
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
|
22 |
+
if is_realesrgan:
|
23 |
+
raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
|
24 |
+
else:
|
25 |
+
raise Exception("The file is not a ESRGAN model.")
|
26 |
+
|
27 |
+
crt_net = crt_model.state_dict()
|
28 |
+
load_net_clean = {}
|
29 |
+
for k, v in pretrained_net.items():
|
30 |
+
if k.startswith('module.'):
|
31 |
+
load_net_clean[k[7:]] = v
|
32 |
+
else:
|
33 |
+
load_net_clean[k] = v
|
34 |
+
pretrained_net = load_net_clean
|
35 |
+
|
36 |
+
tbd = []
|
37 |
+
for k, v in crt_net.items():
|
38 |
+
tbd.append(k)
|
39 |
+
|
40 |
+
# directly copy
|
41 |
+
for k, v in crt_net.items():
|
42 |
+
if k in pretrained_net and pretrained_net[k].size() == v.size():
|
43 |
+
crt_net[k] = pretrained_net[k]
|
44 |
+
tbd.remove(k)
|
45 |
+
|
46 |
+
crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
|
47 |
+
crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
|
48 |
+
|
49 |
+
for k in tbd.copy():
|
50 |
+
if 'RDB' in k:
|
51 |
+
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
|
52 |
+
if '.weight' in k:
|
53 |
+
ori_k = ori_k.replace('.weight', '.0.weight')
|
54 |
+
elif '.bias' in k:
|
55 |
+
ori_k = ori_k.replace('.bias', '.0.bias')
|
56 |
+
crt_net[k] = pretrained_net[ori_k]
|
57 |
+
tbd.remove(k)
|
58 |
+
|
59 |
+
crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
|
60 |
+
crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
|
61 |
+
crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
|
62 |
+
crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
|
63 |
+
crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
|
64 |
+
crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
|
65 |
+
crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
|
66 |
+
crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
|
67 |
+
crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
|
68 |
+
crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
|
69 |
+
|
70 |
+
return crt_net
|
71 |
+
|
72 |
+
class UpscalerESRGAN(Upscaler):
|
73 |
+
def __init__(self, dirname):
|
74 |
+
self.name = "ESRGAN"
|
75 |
+
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
|
76 |
+
self.model_name = "ESRGAN_4x"
|
77 |
+
self.scalers = []
|
78 |
+
self.user_path = dirname
|
79 |
+
self.model_path = os.path.join(models_path, self.name)
|
80 |
+
super().__init__()
|
81 |
+
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
82 |
+
scalers = []
|
83 |
+
if len(model_paths) == 0:
|
84 |
+
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
|
85 |
+
scalers.append(scaler_data)
|
86 |
+
for file in model_paths:
|
87 |
+
if "http" in file:
|
88 |
+
name = self.model_name
|
89 |
+
else:
|
90 |
+
name = modelloader.friendly_name(file)
|
91 |
+
|
92 |
+
scaler_data = UpscalerData(name, file, self, 4)
|
93 |
+
self.scalers.append(scaler_data)
|
94 |
+
|
95 |
+
def do_upscale(self, img, selected_model):
|
96 |
+
model = self.load_model(selected_model)
|
97 |
+
if model is None:
|
98 |
+
return img
|
99 |
+
model.to(devices.device_esrgan)
|
100 |
+
img = esrgan_upscale(model, img)
|
101 |
+
return img
|
102 |
+
|
103 |
+
def load_model(self, path: str):
|
104 |
+
if "http" in path:
|
105 |
+
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
106 |
+
file_name="%s.pth" % self.model_name,
|
107 |
+
progress=True)
|
108 |
+
else:
|
109 |
+
filename = path
|
110 |
+
if not os.path.exists(filename) or filename is None:
|
111 |
+
print("Unable to load %s from %s" % (self.model_path, filename))
|
112 |
+
return None
|
113 |
+
|
114 |
+
pretrained_net = torch.load(filename, map_location='cpu' if shared.device.type == 'mps' else None)
|
115 |
+
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
116 |
+
|
117 |
+
pretrained_net = fix_model_layers(crt_model, pretrained_net)
|
118 |
+
crt_model.load_state_dict(pretrained_net)
|
119 |
+
crt_model.eval()
|
120 |
+
|
121 |
+
return crt_model
|
122 |
+
|
123 |
+
|
124 |
+
def upscale_without_tiling(model, img):
|
125 |
+
img = np.array(img)
|
126 |
+
img = img[:, :, ::-1]
|
127 |
+
img = np.moveaxis(img, 2, 0) / 255
|
128 |
+
img = torch.from_numpy(img).float()
|
129 |
+
img = img.unsqueeze(0).to(devices.device_esrgan)
|
130 |
+
with torch.no_grad():
|
131 |
+
output = model(img)
|
132 |
+
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
133 |
+
output = 255. * np.moveaxis(output, 0, 2)
|
134 |
+
output = output.astype(np.uint8)
|
135 |
+
output = output[:, :, ::-1]
|
136 |
+
return Image.fromarray(output, 'RGB')
|
137 |
+
|
138 |
+
|
139 |
+
def esrgan_upscale(model, img):
|
140 |
+
if opts.ESRGAN_tile == 0:
|
141 |
+
return upscale_without_tiling(model, img)
|
142 |
+
|
143 |
+
grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
|
144 |
+
newtiles = []
|
145 |
+
scale_factor = 1
|
146 |
+
|
147 |
+
for y, h, row in grid.tiles:
|
148 |
+
newrow = []
|
149 |
+
for tiledata in row:
|
150 |
+
x, w, tile = tiledata
|
151 |
+
|
152 |
+
output = upscale_without_tiling(model, tile)
|
153 |
+
scale_factor = output.width // tile.width
|
154 |
+
|
155 |
+
newrow.append([x * scale_factor, w * scale_factor, output])
|
156 |
+
newtiles.append([y * scale_factor, h * scale_factor, newrow])
|
157 |
+
|
158 |
+
newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
|
159 |
+
output = images.combine_grid(newgrid)
|
160 |
+
return output
|
modules/extras.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import tqdm
|
8 |
+
|
9 |
+
from modules import processing, shared, images, devices, sd_models
|
10 |
+
from modules.shared import opts
|
11 |
+
import modules.gfpgan_model
|
12 |
+
from modules.ui import plaintext_to_html
|
13 |
+
import modules.codeformer_model
|
14 |
+
import piexif
|
15 |
+
import piexif.helper
|
16 |
+
import gradio as gr
|
17 |
+
|
18 |
+
|
19 |
+
cached_images = {}
|
20 |
+
|
21 |
+
|
22 |
+
def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
|
23 |
+
devices.torch_gc()
|
24 |
+
|
25 |
+
imageArr = []
|
26 |
+
# Also keep track of original file names
|
27 |
+
imageNameArr = []
|
28 |
+
|
29 |
+
if extras_mode == 1:
|
30 |
+
#convert file to pillow image
|
31 |
+
for img in image_folder:
|
32 |
+
image = Image.fromarray(np.array(Image.open(img)))
|
33 |
+
imageArr.append(image)
|
34 |
+
imageNameArr.append(os.path.splitext(img.orig_name)[0])
|
35 |
+
else:
|
36 |
+
imageArr.append(image)
|
37 |
+
imageNameArr.append(None)
|
38 |
+
|
39 |
+
outpath = opts.outdir_samples or opts.outdir_extras_samples
|
40 |
+
|
41 |
+
outputs = []
|
42 |
+
for image, image_name in zip(imageArr, imageNameArr):
|
43 |
+
if image is None:
|
44 |
+
return outputs, "Please select an input image.", ''
|
45 |
+
existing_pnginfo = image.info or {}
|
46 |
+
|
47 |
+
image = image.convert("RGB")
|
48 |
+
info = ""
|
49 |
+
|
50 |
+
if gfpgan_visibility > 0:
|
51 |
+
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
|
52 |
+
res = Image.fromarray(restored_img)
|
53 |
+
|
54 |
+
if gfpgan_visibility < 1.0:
|
55 |
+
res = Image.blend(image, res, gfpgan_visibility)
|
56 |
+
|
57 |
+
info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
|
58 |
+
image = res
|
59 |
+
|
60 |
+
if codeformer_visibility > 0:
|
61 |
+
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
|
62 |
+
res = Image.fromarray(restored_img)
|
63 |
+
|
64 |
+
if codeformer_visibility < 1.0:
|
65 |
+
res = Image.blend(image, res, codeformer_visibility)
|
66 |
+
|
67 |
+
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
|
68 |
+
image = res
|
69 |
+
|
70 |
+
if upscaling_resize != 1.0:
|
71 |
+
def upscale(image, scaler_index, resize):
|
72 |
+
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
|
73 |
+
pixels = tuple(np.array(small).flatten().tolist())
|
74 |
+
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
|
75 |
+
|
76 |
+
c = cached_images.get(key)
|
77 |
+
if c is None:
|
78 |
+
upscaler = shared.sd_upscalers[scaler_index]
|
79 |
+
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
80 |
+
cached_images[key] = c
|
81 |
+
|
82 |
+
return c
|
83 |
+
|
84 |
+
info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
|
85 |
+
res = upscale(image, extras_upscaler_1, upscaling_resize)
|
86 |
+
|
87 |
+
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
|
88 |
+
res2 = upscale(image, extras_upscaler_2, upscaling_resize)
|
89 |
+
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
|
90 |
+
res = Image.blend(res, res2, extras_upscaler_2_visibility)
|
91 |
+
|
92 |
+
image = res
|
93 |
+
|
94 |
+
while len(cached_images) > 2:
|
95 |
+
del cached_images[next(iter(cached_images.keys()))]
|
96 |
+
|
97 |
+
images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
|
98 |
+
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
|
99 |
+
forced_filename=image_name if opts.use_original_name_batch else None)
|
100 |
+
|
101 |
+
outputs.append(image)
|
102 |
+
|
103 |
+
devices.torch_gc()
|
104 |
+
|
105 |
+
return outputs, plaintext_to_html(info), ''
|
106 |
+
|
107 |
+
|
108 |
+
def run_pnginfo(image):
|
109 |
+
if image is None:
|
110 |
+
return '', '', ''
|
111 |
+
|
112 |
+
items = image.info
|
113 |
+
geninfo = ''
|
114 |
+
|
115 |
+
if "exif" in image.info:
|
116 |
+
exif = piexif.load(image.info["exif"])
|
117 |
+
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
|
118 |
+
try:
|
119 |
+
exif_comment = piexif.helper.UserComment.load(exif_comment)
|
120 |
+
except ValueError:
|
121 |
+
exif_comment = exif_comment.decode('utf8', errors="ignore")
|
122 |
+
|
123 |
+
items['exif comment'] = exif_comment
|
124 |
+
geninfo = exif_comment
|
125 |
+
|
126 |
+
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
|
127 |
+
'loop', 'background', 'timestamp', 'duration']:
|
128 |
+
items.pop(field, None)
|
129 |
+
|
130 |
+
geninfo = items.get('parameters', geninfo)
|
131 |
+
|
132 |
+
info = ''
|
133 |
+
for key, text in items.items():
|
134 |
+
info += f"""
|
135 |
+
<div>
|
136 |
+
<p><b>{plaintext_to_html(str(key))}</b></p>
|
137 |
+
<p>{plaintext_to_html(str(text))}</p>
|
138 |
+
</div>
|
139 |
+
""".strip()+"\n"
|
140 |
+
|
141 |
+
if len(info) == 0:
|
142 |
+
message = "Nothing found in the image."
|
143 |
+
info = f"<div><p>{message}<p></div>"
|
144 |
+
|
145 |
+
return '', geninfo, info
|
146 |
+
|
147 |
+
|
148 |
+
def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
|
149 |
+
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
|
150 |
+
def weighted_sum(theta0, theta1, alpha):
|
151 |
+
return ((1 - alpha) * theta0) + (alpha * theta1)
|
152 |
+
|
153 |
+
# Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
|
154 |
+
def sigmoid(theta0, theta1, alpha):
|
155 |
+
alpha = alpha * alpha * (3 - (2 * alpha))
|
156 |
+
return theta0 + ((theta1 - theta0) * alpha)
|
157 |
+
|
158 |
+
# Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
|
159 |
+
def inv_sigmoid(theta0, theta1, alpha):
|
160 |
+
import math
|
161 |
+
alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
|
162 |
+
return theta0 + ((theta1 - theta0) * alpha)
|
163 |
+
|
164 |
+
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
165 |
+
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
|
166 |
+
|
167 |
+
print(f"Loading {primary_model_info.filename}...")
|
168 |
+
primary_model = torch.load(primary_model_info.filename, map_location='cpu')
|
169 |
+
|
170 |
+
print(f"Loading {secondary_model_info.filename}...")
|
171 |
+
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
|
172 |
+
|
173 |
+
theta_0 = primary_model['state_dict']
|
174 |
+
theta_1 = secondary_model['state_dict']
|
175 |
+
|
176 |
+
theta_funcs = {
|
177 |
+
"Weighted Sum": weighted_sum,
|
178 |
+
"Sigmoid": sigmoid,
|
179 |
+
"Inverse Sigmoid": inv_sigmoid,
|
180 |
+
}
|
181 |
+
theta_func = theta_funcs[interp_method]
|
182 |
+
|
183 |
+
print(f"Merging...")
|
184 |
+
for key in tqdm.tqdm(theta_0.keys()):
|
185 |
+
if 'model' in key and key in theta_1:
|
186 |
+
theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
|
187 |
+
if save_as_half:
|
188 |
+
theta_0[key] = theta_0[key].half()
|
189 |
+
|
190 |
+
for key in theta_1.keys():
|
191 |
+
if 'model' in key and key not in theta_0:
|
192 |
+
theta_0[key] = theta_1[key]
|
193 |
+
if save_as_half:
|
194 |
+
theta_0[key] = theta_0[key].half()
|
195 |
+
|
196 |
+
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
197 |
+
|
198 |
+
filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
|
199 |
+
filename = filename if custom_name == '' else (custom_name + '.ckpt')
|
200 |
+
output_modelname = os.path.join(ckpt_dir, filename)
|
201 |
+
|
202 |
+
print(f"Saving to {output_modelname}...")
|
203 |
+
torch.save(primary_model, output_modelname)
|
204 |
+
|
205 |
+
sd_models.list_models()
|
206 |
+
|
207 |
+
print(f"Checkpoint saved.")
|
208 |
+
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)]
|
modules/face_restoration.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules import shared
|
2 |
+
|
3 |
+
|
4 |
+
class FaceRestoration:
|
5 |
+
def name(self):
|
6 |
+
return "None"
|
7 |
+
|
8 |
+
def restore(self, np_image):
|
9 |
+
return np_image
|
10 |
+
|
11 |
+
|
12 |
+
def restore_faces(np_image):
|
13 |
+
face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
|
14 |
+
if len(face_restorers) == 0:
|
15 |
+
return np_image
|
16 |
+
|
17 |
+
face_restorer = face_restorers[0]
|
18 |
+
|
19 |
+
return face_restorer.restore(np_image)
|
modules/generation_parameters_copypaste.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import gradio as gr
|
3 |
+
|
4 |
+
re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
|
5 |
+
re_param = re.compile(re_param_code)
|
6 |
+
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
|
7 |
+
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
8 |
+
type_of_gr_update = type(gr.update())
|
9 |
+
|
10 |
+
|
11 |
+
def parse_generation_parameters(x: str):
|
12 |
+
"""parses generation parameters string, the one you see in text field under the picture in UI:
|
13 |
+
```
|
14 |
+
girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
|
15 |
+
Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
|
16 |
+
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
|
17 |
+
```
|
18 |
+
|
19 |
+
returns a dict with field values
|
20 |
+
"""
|
21 |
+
|
22 |
+
res = {}
|
23 |
+
|
24 |
+
prompt = ""
|
25 |
+
negative_prompt = ""
|
26 |
+
|
27 |
+
done_with_prompt = False
|
28 |
+
|
29 |
+
*lines, lastline = x.strip().split("\n")
|
30 |
+
if not re_params.match(lastline):
|
31 |
+
lines.append(lastline)
|
32 |
+
lastline = ''
|
33 |
+
|
34 |
+
for i, line in enumerate(lines):
|
35 |
+
line = line.strip()
|
36 |
+
if line.startswith("Negative prompt:"):
|
37 |
+
done_with_prompt = True
|
38 |
+
line = line[16:].strip()
|
39 |
+
|
40 |
+
if done_with_prompt:
|
41 |
+
negative_prompt += ("" if negative_prompt == "" else "\n") + line
|
42 |
+
else:
|
43 |
+
prompt += ("" if prompt == "" else "\n") + line
|
44 |
+
|
45 |
+
if len(prompt) > 0:
|
46 |
+
res["Prompt"] = prompt
|
47 |
+
|
48 |
+
if len(negative_prompt) > 0:
|
49 |
+
res["Negative prompt"] = negative_prompt
|
50 |
+
|
51 |
+
for k, v in re_param.findall(lastline):
|
52 |
+
m = re_imagesize.match(v)
|
53 |
+
if m is not None:
|
54 |
+
res[k+"-1"] = m.group(1)
|
55 |
+
res[k+"-2"] = m.group(2)
|
56 |
+
else:
|
57 |
+
res[k] = v
|
58 |
+
|
59 |
+
return res
|
60 |
+
|
61 |
+
|
62 |
+
def connect_paste(button, paste_fields, input_comp, js=None):
|
63 |
+
def paste_func(prompt):
|
64 |
+
params = parse_generation_parameters(prompt)
|
65 |
+
res = []
|
66 |
+
|
67 |
+
for output, key in paste_fields:
|
68 |
+
if callable(key):
|
69 |
+
v = key(params)
|
70 |
+
else:
|
71 |
+
v = params.get(key, None)
|
72 |
+
|
73 |
+
if v is None:
|
74 |
+
res.append(gr.update())
|
75 |
+
elif isinstance(v, type_of_gr_update):
|
76 |
+
res.append(v)
|
77 |
+
else:
|
78 |
+
try:
|
79 |
+
valtype = type(output.value)
|
80 |
+
val = valtype(v)
|
81 |
+
res.append(gr.update(value=val))
|
82 |
+
except Exception:
|
83 |
+
res.append(gr.update())
|
84 |
+
|
85 |
+
return res
|
86 |
+
|
87 |
+
button.click(
|
88 |
+
fn=paste_func,
|
89 |
+
_js=js,
|
90 |
+
inputs=[input_comp],
|
91 |
+
outputs=[x[0] for x in paste_fields],
|
92 |
+
)
|
modules/gfpgan_model.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
import facexlib
|
6 |
+
import gfpgan
|
7 |
+
|
8 |
+
import modules.face_restoration
|
9 |
+
from modules import shared, devices, modelloader
|
10 |
+
from modules.paths import models_path
|
11 |
+
|
12 |
+
model_dir = "GFPGAN"
|
13 |
+
user_path = None
|
14 |
+
model_path = os.path.join(models_path, model_dir)
|
15 |
+
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
16 |
+
have_gfpgan = False
|
17 |
+
loaded_gfpgan_model = None
|
18 |
+
|
19 |
+
|
20 |
+
def gfpgann():
|
21 |
+
global loaded_gfpgan_model
|
22 |
+
global model_path
|
23 |
+
if loaded_gfpgan_model is not None:
|
24 |
+
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
|
25 |
+
return loaded_gfpgan_model
|
26 |
+
|
27 |
+
if gfpgan_constructor is None:
|
28 |
+
return None
|
29 |
+
|
30 |
+
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
|
31 |
+
if len(models) == 1 and "http" in models[0]:
|
32 |
+
model_file = models[0]
|
33 |
+
elif len(models) != 0:
|
34 |
+
latest_file = max(models, key=os.path.getctime)
|
35 |
+
model_file = latest_file
|
36 |
+
else:
|
37 |
+
print("Unable to load gfpgan model!")
|
38 |
+
return None
|
39 |
+
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
40 |
+
loaded_gfpgan_model = model
|
41 |
+
|
42 |
+
return model
|
43 |
+
|
44 |
+
|
45 |
+
def send_model_to(model, device):
|
46 |
+
model.gfpgan.to(device)
|
47 |
+
model.face_helper.face_det.to(device)
|
48 |
+
model.face_helper.face_parse.to(device)
|
49 |
+
|
50 |
+
|
51 |
+
def gfpgan_fix_faces(np_image):
|
52 |
+
model = gfpgann()
|
53 |
+
if model is None:
|
54 |
+
return np_image
|
55 |
+
|
56 |
+
send_model_to(model, devices.device_gfpgan)
|
57 |
+
|
58 |
+
np_image_bgr = np_image[:, :, ::-1]
|
59 |
+
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
60 |
+
np_image = gfpgan_output_bgr[:, :, ::-1]
|
61 |
+
|
62 |
+
model.face_helper.clean_all()
|
63 |
+
|
64 |
+
if shared.opts.face_restoration_unload:
|
65 |
+
send_model_to(model, devices.cpu)
|
66 |
+
|
67 |
+
return np_image
|
68 |
+
|
69 |
+
|
70 |
+
gfpgan_constructor = None
|
71 |
+
|
72 |
+
|
73 |
+
def setup_model(dirname):
|
74 |
+
global model_path
|
75 |
+
if not os.path.exists(model_path):
|
76 |
+
os.makedirs(model_path)
|
77 |
+
|
78 |
+
try:
|
79 |
+
from gfpgan import GFPGANer
|
80 |
+
from facexlib import detection, parsing
|
81 |
+
global user_path
|
82 |
+
global have_gfpgan
|
83 |
+
global gfpgan_constructor
|
84 |
+
|
85 |
+
load_file_from_url_orig = gfpgan.utils.load_file_from_url
|
86 |
+
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
|
87 |
+
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
|
88 |
+
|
89 |
+
def my_load_file_from_url(**kwargs):
|
90 |
+
return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
|
91 |
+
|
92 |
+
def facex_load_file_from_url(**kwargs):
|
93 |
+
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
|
94 |
+
|
95 |
+
def facex_load_file_from_url2(**kwargs):
|
96 |
+
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
|
97 |
+
|
98 |
+
gfpgan.utils.load_file_from_url = my_load_file_from_url
|
99 |
+
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
100 |
+
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
|
101 |
+
user_path = dirname
|
102 |
+
have_gfpgan = True
|
103 |
+
gfpgan_constructor = GFPGANer
|
104 |
+
|
105 |
+
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
|
106 |
+
def name(self):
|
107 |
+
return "GFPGAN"
|
108 |
+
|
109 |
+
def restore(self, np_image):
|
110 |
+
return gfpgan_fix_faces(np_image)
|
111 |
+
|
112 |
+
shared.face_restorers.append(FaceRestorerGFPGAN())
|
113 |
+
except Exception:
|
114 |
+
print("Error setting up GFPGAN:", file=sys.stderr)
|
115 |
+
print(traceback.format_exc(), file=sys.stderr)
|
modules/hypernetwork.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import traceback
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from ldm.util import default
|
9 |
+
from modules import devices, shared
|
10 |
+
import torch
|
11 |
+
from torch import einsum
|
12 |
+
from einops import rearrange, repeat
|
13 |
+
|
14 |
+
|
15 |
+
class HypernetworkModule(torch.nn.Module):
|
16 |
+
def __init__(self, dim, state_dict):
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
self.linear1 = torch.nn.Linear(dim, dim * 2)
|
20 |
+
self.linear2 = torch.nn.Linear(dim * 2, dim)
|
21 |
+
|
22 |
+
self.load_state_dict(state_dict, strict=True)
|
23 |
+
self.to(devices.device)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
return x + (self.linear2(self.linear1(x)))
|
27 |
+
|
28 |
+
|
29 |
+
class Hypernetwork:
|
30 |
+
filename = None
|
31 |
+
name = None
|
32 |
+
|
33 |
+
def __init__(self, filename):
|
34 |
+
self.filename = filename
|
35 |
+
self.name = os.path.splitext(os.path.basename(filename))[0]
|
36 |
+
self.layers = {}
|
37 |
+
|
38 |
+
state_dict = torch.load(filename, map_location='cpu')
|
39 |
+
for size, sd in state_dict.items():
|
40 |
+
self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
|
41 |
+
|
42 |
+
|
43 |
+
def load_hypernetworks(path):
|
44 |
+
res = {}
|
45 |
+
|
46 |
+
for filename in glob.iglob(path + '**/*.pt', recursive=True):
|
47 |
+
try:
|
48 |
+
hn = Hypernetwork(filename)
|
49 |
+
res[hn.name] = hn
|
50 |
+
except Exception:
|
51 |
+
print(f"Error loading hypernetwork {filename}", file=sys.stderr)
|
52 |
+
print(traceback.format_exc(), file=sys.stderr)
|
53 |
+
|
54 |
+
return res
|
55 |
+
|
56 |
+
|
57 |
+
def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
58 |
+
h = self.heads
|
59 |
+
|
60 |
+
q = self.to_q(x)
|
61 |
+
context = default(context, x)
|
62 |
+
|
63 |
+
hypernetwork = shared.selected_hypernetwork()
|
64 |
+
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
65 |
+
|
66 |
+
if hypernetwork_layers is not None:
|
67 |
+
k = self.to_k(hypernetwork_layers[0](context))
|
68 |
+
v = self.to_v(hypernetwork_layers[1](context))
|
69 |
+
else:
|
70 |
+
k = self.to_k(context)
|
71 |
+
v = self.to_v(context)
|
72 |
+
|
73 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
74 |
+
|
75 |
+
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
76 |
+
|
77 |
+
if mask is not None:
|
78 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
79 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
80 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
81 |
+
sim.masked_fill_(~mask, max_neg_value)
|
82 |
+
|
83 |
+
# attention, what we cannot get enough of
|
84 |
+
attn = sim.softmax(dim=-1)
|
85 |
+
|
86 |
+
out = einsum('b i j, b j d -> b i d', attn, v)
|
87 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
88 |
+
return self.to_out(out)
|
modules/images.py
ADDED
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
from collections import namedtuple
|
5 |
+
import re
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import piexif
|
9 |
+
import piexif.helper
|
10 |
+
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
11 |
+
from fonts.ttf import Roboto
|
12 |
+
import string
|
13 |
+
|
14 |
+
from modules import sd_samplers, shared
|
15 |
+
from modules.shared import opts, cmd_opts
|
16 |
+
|
17 |
+
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
18 |
+
|
19 |
+
|
20 |
+
def image_grid(imgs, batch_size=1, rows=None):
|
21 |
+
if rows is None:
|
22 |
+
if opts.n_rows > 0:
|
23 |
+
rows = opts.n_rows
|
24 |
+
elif opts.n_rows == 0:
|
25 |
+
rows = batch_size
|
26 |
+
else:
|
27 |
+
rows = math.sqrt(len(imgs))
|
28 |
+
rows = round(rows)
|
29 |
+
|
30 |
+
cols = math.ceil(len(imgs) / rows)
|
31 |
+
|
32 |
+
w, h = imgs[0].size
|
33 |
+
grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
|
34 |
+
|
35 |
+
for i, img in enumerate(imgs):
|
36 |
+
grid.paste(img, box=(i % cols * w, i // cols * h))
|
37 |
+
|
38 |
+
return grid
|
39 |
+
|
40 |
+
|
41 |
+
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
|
42 |
+
|
43 |
+
|
44 |
+
def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
45 |
+
w = image.width
|
46 |
+
h = image.height
|
47 |
+
|
48 |
+
non_overlap_width = tile_w - overlap
|
49 |
+
non_overlap_height = tile_h - overlap
|
50 |
+
|
51 |
+
cols = math.ceil((w - overlap) / non_overlap_width)
|
52 |
+
rows = math.ceil((h - overlap) / non_overlap_height)
|
53 |
+
|
54 |
+
dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
|
55 |
+
dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
|
56 |
+
|
57 |
+
grid = Grid([], tile_w, tile_h, w, h, overlap)
|
58 |
+
for row in range(rows):
|
59 |
+
row_images = []
|
60 |
+
|
61 |
+
y = int(row * dy)
|
62 |
+
|
63 |
+
if y + tile_h >= h:
|
64 |
+
y = h - tile_h
|
65 |
+
|
66 |
+
for col in range(cols):
|
67 |
+
x = int(col * dx)
|
68 |
+
|
69 |
+
if x + tile_w >= w:
|
70 |
+
x = w - tile_w
|
71 |
+
|
72 |
+
tile = image.crop((x, y, x + tile_w, y + tile_h))
|
73 |
+
|
74 |
+
row_images.append([x, tile_w, tile])
|
75 |
+
|
76 |
+
grid.tiles.append([y, tile_h, row_images])
|
77 |
+
|
78 |
+
return grid
|
79 |
+
|
80 |
+
|
81 |
+
def combine_grid(grid):
|
82 |
+
def make_mask_image(r):
|
83 |
+
r = r * 255 / grid.overlap
|
84 |
+
r = r.astype(np.uint8)
|
85 |
+
return Image.fromarray(r, 'L')
|
86 |
+
|
87 |
+
mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
|
88 |
+
mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
|
89 |
+
|
90 |
+
combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
|
91 |
+
for y, h, row in grid.tiles:
|
92 |
+
combined_row = Image.new("RGB", (grid.image_w, h))
|
93 |
+
for x, w, tile in row:
|
94 |
+
if x == 0:
|
95 |
+
combined_row.paste(tile, (0, 0))
|
96 |
+
continue
|
97 |
+
|
98 |
+
combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
|
99 |
+
combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
|
100 |
+
|
101 |
+
if y == 0:
|
102 |
+
combined_image.paste(combined_row, (0, 0))
|
103 |
+
continue
|
104 |
+
|
105 |
+
combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
|
106 |
+
combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
|
107 |
+
|
108 |
+
return combined_image
|
109 |
+
|
110 |
+
|
111 |
+
class GridAnnotation:
|
112 |
+
def __init__(self, text='', is_active=True):
|
113 |
+
self.text = text
|
114 |
+
self.is_active = is_active
|
115 |
+
self.size = None
|
116 |
+
|
117 |
+
|
118 |
+
def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
119 |
+
def wrap(drawing, text, font, line_length):
|
120 |
+
lines = ['']
|
121 |
+
for word in text.split():
|
122 |
+
line = f'{lines[-1]} {word}'.strip()
|
123 |
+
if drawing.textlength(line, font=font) <= line_length:
|
124 |
+
lines[-1] = line
|
125 |
+
else:
|
126 |
+
lines.append(word)
|
127 |
+
return lines
|
128 |
+
|
129 |
+
def draw_texts(drawing, draw_x, draw_y, lines):
|
130 |
+
for i, line in enumerate(lines):
|
131 |
+
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
|
132 |
+
|
133 |
+
if not line.is_active:
|
134 |
+
drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
|
135 |
+
|
136 |
+
draw_y += line.size[1] + line_spacing
|
137 |
+
|
138 |
+
fontsize = (width + height) // 25
|
139 |
+
line_spacing = fontsize // 2
|
140 |
+
|
141 |
+
try:
|
142 |
+
fnt = ImageFont.truetype(opts.font or Roboto, fontsize)
|
143 |
+
except Exception:
|
144 |
+
fnt = ImageFont.truetype(Roboto, fontsize)
|
145 |
+
|
146 |
+
color_active = (0, 0, 0)
|
147 |
+
color_inactive = (153, 153, 153)
|
148 |
+
|
149 |
+
pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
|
150 |
+
|
151 |
+
cols = im.width // width
|
152 |
+
rows = im.height // height
|
153 |
+
|
154 |
+
assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
|
155 |
+
assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
|
156 |
+
|
157 |
+
calc_img = Image.new("RGB", (1, 1), "white")
|
158 |
+
calc_d = ImageDraw.Draw(calc_img)
|
159 |
+
|
160 |
+
for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
|
161 |
+
items = [] + texts
|
162 |
+
texts.clear()
|
163 |
+
|
164 |
+
for line in items:
|
165 |
+
wrapped = wrap(calc_d, line.text, fnt, allowed_width)
|
166 |
+
texts += [GridAnnotation(x, line.is_active) for x in wrapped]
|
167 |
+
|
168 |
+
for line in texts:
|
169 |
+
bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
|
170 |
+
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
|
171 |
+
|
172 |
+
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
|
173 |
+
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
|
174 |
+
ver_texts]
|
175 |
+
|
176 |
+
pad_top = max(hor_text_heights) + line_spacing * 2
|
177 |
+
|
178 |
+
result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
|
179 |
+
result.paste(im, (pad_left, pad_top))
|
180 |
+
|
181 |
+
d = ImageDraw.Draw(result)
|
182 |
+
|
183 |
+
for col in range(cols):
|
184 |
+
x = pad_left + width * col + width / 2
|
185 |
+
y = pad_top / 2 - hor_text_heights[col] / 2
|
186 |
+
|
187 |
+
draw_texts(d, x, y, hor_texts[col])
|
188 |
+
|
189 |
+
for row in range(rows):
|
190 |
+
x = pad_left / 2
|
191 |
+
y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
|
192 |
+
|
193 |
+
draw_texts(d, x, y, ver_texts[row])
|
194 |
+
|
195 |
+
return result
|
196 |
+
|
197 |
+
|
198 |
+
def draw_prompt_matrix(im, width, height, all_prompts):
|
199 |
+
prompts = all_prompts[1:]
|
200 |
+
boundary = math.ceil(len(prompts) / 2)
|
201 |
+
|
202 |
+
prompts_horiz = prompts[:boundary]
|
203 |
+
prompts_vert = prompts[boundary:]
|
204 |
+
|
205 |
+
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
|
206 |
+
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
|
207 |
+
|
208 |
+
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
|
209 |
+
|
210 |
+
|
211 |
+
def resize_image(resize_mode, im, width, height):
|
212 |
+
def resize(im, w, h):
|
213 |
+
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
|
214 |
+
return im.resize((w, h), resample=LANCZOS)
|
215 |
+
|
216 |
+
scale = max(w / im.width, h / im.height)
|
217 |
+
|
218 |
+
if scale > 1.0:
|
219 |
+
upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
|
220 |
+
assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
|
221 |
+
|
222 |
+
upscaler = upscalers[0]
|
223 |
+
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
|
224 |
+
|
225 |
+
if im.width != w or im.height != h:
|
226 |
+
im = im.resize((w, h), resample=LANCZOS)
|
227 |
+
|
228 |
+
return im
|
229 |
+
|
230 |
+
if resize_mode == 0:
|
231 |
+
res = resize(im, width, height)
|
232 |
+
|
233 |
+
elif resize_mode == 1:
|
234 |
+
ratio = width / height
|
235 |
+
src_ratio = im.width / im.height
|
236 |
+
|
237 |
+
src_w = width if ratio > src_ratio else im.width * height // im.height
|
238 |
+
src_h = height if ratio <= src_ratio else im.height * width // im.width
|
239 |
+
|
240 |
+
resized = resize(im, src_w, src_h)
|
241 |
+
res = Image.new("RGB", (width, height))
|
242 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
243 |
+
|
244 |
+
else:
|
245 |
+
ratio = width / height
|
246 |
+
src_ratio = im.width / im.height
|
247 |
+
|
248 |
+
src_w = width if ratio < src_ratio else im.width * height // im.height
|
249 |
+
src_h = height if ratio >= src_ratio else im.height * width // im.width
|
250 |
+
|
251 |
+
resized = resize(im, src_w, src_h)
|
252 |
+
res = Image.new("RGB", (width, height))
|
253 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
254 |
+
|
255 |
+
if ratio < src_ratio:
|
256 |
+
fill_height = height // 2 - src_h // 2
|
257 |
+
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
258 |
+
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
|
259 |
+
elif ratio > src_ratio:
|
260 |
+
fill_width = width // 2 - src_w // 2
|
261 |
+
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
262 |
+
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
|
263 |
+
|
264 |
+
return res
|
265 |
+
|
266 |
+
|
267 |
+
invalid_filename_chars = '<>:"/\\|?*\n'
|
268 |
+
invalid_filename_prefix = ' '
|
269 |
+
invalid_filename_postfix = ' .'
|
270 |
+
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
271 |
+
max_filename_part_length = 128
|
272 |
+
|
273 |
+
|
274 |
+
def sanitize_filename_part(text, replace_spaces=True):
|
275 |
+
if replace_spaces:
|
276 |
+
text = text.replace(' ', '_')
|
277 |
+
|
278 |
+
text = text.translate({ord(x): '_' for x in invalid_filename_chars})
|
279 |
+
text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
|
280 |
+
text = text.rstrip(invalid_filename_postfix)
|
281 |
+
return text
|
282 |
+
|
283 |
+
|
284 |
+
def apply_filename_pattern(x, p, seed, prompt):
|
285 |
+
max_prompt_words = opts.directories_max_prompt_words
|
286 |
+
|
287 |
+
if seed is not None:
|
288 |
+
x = x.replace("[seed]", str(seed))
|
289 |
+
|
290 |
+
if p is not None:
|
291 |
+
x = x.replace("[steps]", str(p.steps))
|
292 |
+
x = x.replace("[cfg]", str(p.cfg_scale))
|
293 |
+
x = x.replace("[width]", str(p.width))
|
294 |
+
x = x.replace("[height]", str(p.height))
|
295 |
+
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
|
296 |
+
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
|
297 |
+
|
298 |
+
x = x.replace("[model_hash]", getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash))
|
299 |
+
x = x.replace("[date]", datetime.date.today().isoformat())
|
300 |
+
x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
|
301 |
+
x = x.replace("[job_timestamp]", getattr(p, "job_timestamp", shared.state.job_timestamp))
|
302 |
+
|
303 |
+
# Apply [prompt] at last. Because it may contain any replacement word.^M
|
304 |
+
if prompt is not None:
|
305 |
+
x = x.replace("[prompt]", sanitize_filename_part(prompt))
|
306 |
+
if "[prompt_no_styles]" in x:
|
307 |
+
prompt_no_style = prompt
|
308 |
+
for style in shared.prompt_styles.get_style_prompts(p.styles):
|
309 |
+
if len(style) > 0:
|
310 |
+
style_parts = [y for y in style.split("{prompt}")]
|
311 |
+
for part in style_parts:
|
312 |
+
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
|
313 |
+
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
|
314 |
+
x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
|
315 |
+
|
316 |
+
x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False))
|
317 |
+
if "[prompt_words]" in x:
|
318 |
+
words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
|
319 |
+
if len(words) == 0:
|
320 |
+
words = ["empty"]
|
321 |
+
x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
|
322 |
+
|
323 |
+
if cmd_opts.hide_ui_dir_config:
|
324 |
+
x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
|
325 |
+
|
326 |
+
return x
|
327 |
+
|
328 |
+
|
329 |
+
def get_next_sequence_number(path, basename):
|
330 |
+
"""
|
331 |
+
Determines and returns the next sequence number to use when saving an image in the specified directory.
|
332 |
+
|
333 |
+
The sequence starts at 0.
|
334 |
+
"""
|
335 |
+
result = -1
|
336 |
+
if basename != '':
|
337 |
+
basename = basename + "-"
|
338 |
+
|
339 |
+
prefix_length = len(basename)
|
340 |
+
for p in os.listdir(path):
|
341 |
+
if p.startswith(basename):
|
342 |
+
l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
343 |
+
try:
|
344 |
+
result = max(int(l[0]), result)
|
345 |
+
except ValueError:
|
346 |
+
pass
|
347 |
+
|
348 |
+
return result + 1
|
349 |
+
|
350 |
+
|
351 |
+
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
|
352 |
+
if short_filename or prompt is None or seed is None:
|
353 |
+
file_decoration = ""
|
354 |
+
elif opts.save_to_dirs:
|
355 |
+
file_decoration = opts.samples_filename_pattern or "[seed]"
|
356 |
+
else:
|
357 |
+
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
|
358 |
+
|
359 |
+
if file_decoration != "":
|
360 |
+
file_decoration = "-" + file_decoration.lower()
|
361 |
+
|
362 |
+
file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix
|
363 |
+
|
364 |
+
if extension == 'png' and opts.enable_pnginfo and info is not None:
|
365 |
+
pnginfo = PngImagePlugin.PngInfo()
|
366 |
+
|
367 |
+
if existing_info is not None:
|
368 |
+
for k, v in existing_info.items():
|
369 |
+
pnginfo.add_text(k, str(v))
|
370 |
+
|
371 |
+
pnginfo.add_text(pnginfo_section_name, info)
|
372 |
+
else:
|
373 |
+
pnginfo = None
|
374 |
+
|
375 |
+
if save_to_dirs is None:
|
376 |
+
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
377 |
+
|
378 |
+
if save_to_dirs:
|
379 |
+
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
|
380 |
+
path = os.path.join(path, dirname)
|
381 |
+
|
382 |
+
os.makedirs(path, exist_ok=True)
|
383 |
+
|
384 |
+
if forced_filename is None:
|
385 |
+
basecount = get_next_sequence_number(path, basename)
|
386 |
+
fullfn = "a.png"
|
387 |
+
fullfn_without_extension = "a"
|
388 |
+
for i in range(500):
|
389 |
+
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
|
390 |
+
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
|
391 |
+
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
|
392 |
+
if not os.path.exists(fullfn):
|
393 |
+
break
|
394 |
+
else:
|
395 |
+
fullfn = os.path.join(path, f"{forced_filename}.{extension}")
|
396 |
+
fullfn_without_extension = os.path.join(path, forced_filename)
|
397 |
+
|
398 |
+
def exif_bytes():
|
399 |
+
return piexif.dump({
|
400 |
+
"Exif": {
|
401 |
+
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
|
402 |
+
},
|
403 |
+
})
|
404 |
+
|
405 |
+
if extension.lower() in ("jpg", "jpeg", "webp"):
|
406 |
+
image.save(fullfn, quality=opts.jpeg_quality)
|
407 |
+
if opts.enable_pnginfo and info is not None:
|
408 |
+
piexif.insert(exif_bytes(), fullfn)
|
409 |
+
else:
|
410 |
+
image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
|
411 |
+
|
412 |
+
target_side_length = 4000
|
413 |
+
oversize = image.width > target_side_length or image.height > target_side_length
|
414 |
+
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
|
415 |
+
ratio = image.width / image.height
|
416 |
+
|
417 |
+
if oversize and ratio > 1:
|
418 |
+
image = image.resize((target_side_length, image.height * target_side_length // image.width), LANCZOS)
|
419 |
+
elif oversize:
|
420 |
+
image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)
|
421 |
+
|
422 |
+
image.save(fullfn_without_extension + ".jpg", quality=opts.jpeg_quality)
|
423 |
+
if opts.enable_pnginfo and info is not None:
|
424 |
+
piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")
|
425 |
+
|
426 |
+
if opts.save_txt and info is not None:
|
427 |
+
with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file:
|
428 |
+
file.write(info + "\n")
|
429 |
+
|
430 |
+
return fullfn
|
modules/img2img.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import traceback
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image, ImageOps, ImageChops
|
8 |
+
|
9 |
+
from modules import devices
|
10 |
+
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
11 |
+
from modules.shared import opts, state
|
12 |
+
import modules.shared as shared
|
13 |
+
import modules.processing as processing
|
14 |
+
from modules.ui import plaintext_to_html
|
15 |
+
import modules.images as images
|
16 |
+
import modules.scripts
|
17 |
+
|
18 |
+
|
19 |
+
def process_batch(p, input_dir, output_dir, args):
|
20 |
+
processing.fix_seed(p)
|
21 |
+
|
22 |
+
images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
|
23 |
+
|
24 |
+
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
25 |
+
|
26 |
+
save_normally = output_dir == ''
|
27 |
+
|
28 |
+
p.do_not_save_grid = True
|
29 |
+
p.do_not_save_samples = not save_normally
|
30 |
+
|
31 |
+
state.job_count = len(images) * p.n_iter
|
32 |
+
|
33 |
+
for i, image in enumerate(images):
|
34 |
+
state.job = f"{i+1} out of {len(images)}"
|
35 |
+
|
36 |
+
if state.interrupted:
|
37 |
+
break
|
38 |
+
|
39 |
+
img = Image.open(image)
|
40 |
+
p.init_images = [img] * p.batch_size
|
41 |
+
|
42 |
+
proc = modules.scripts.scripts_img2img.run(p, *args)
|
43 |
+
if proc is None:
|
44 |
+
proc = process_images(p)
|
45 |
+
|
46 |
+
for n, processed_image in enumerate(proc.images):
|
47 |
+
filename = os.path.basename(image)
|
48 |
+
|
49 |
+
if n > 0:
|
50 |
+
left, right = os.path.splitext(filename)
|
51 |
+
filename = f"{left}-{n}{right}"
|
52 |
+
|
53 |
+
if not save_normally:
|
54 |
+
processed_image.save(os.path.join(output_dir, filename))
|
55 |
+
|
56 |
+
|
57 |
+
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
58 |
+
is_inpaint = mode == 1
|
59 |
+
is_batch = mode == 2
|
60 |
+
|
61 |
+
if is_inpaint:
|
62 |
+
if mask_mode == 0:
|
63 |
+
image = init_img_with_mask['image']
|
64 |
+
mask = init_img_with_mask['mask']
|
65 |
+
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
66 |
+
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
|
67 |
+
image = image.convert('RGB')
|
68 |
+
else:
|
69 |
+
image = init_img_inpaint
|
70 |
+
mask = init_mask_inpaint
|
71 |
+
else:
|
72 |
+
image = init_img
|
73 |
+
mask = None
|
74 |
+
|
75 |
+
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
76 |
+
|
77 |
+
p = StableDiffusionProcessingImg2Img(
|
78 |
+
sd_model=shared.sd_model,
|
79 |
+
outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
|
80 |
+
outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
|
81 |
+
prompt=prompt,
|
82 |
+
negative_prompt=negative_prompt,
|
83 |
+
styles=[prompt_style, prompt_style2],
|
84 |
+
seed=seed,
|
85 |
+
subseed=subseed,
|
86 |
+
subseed_strength=subseed_strength,
|
87 |
+
seed_resize_from_h=seed_resize_from_h,
|
88 |
+
seed_resize_from_w=seed_resize_from_w,
|
89 |
+
seed_enable_extras=seed_enable_extras,
|
90 |
+
sampler_index=sampler_index,
|
91 |
+
batch_size=batch_size,
|
92 |
+
n_iter=n_iter,
|
93 |
+
steps=steps,
|
94 |
+
cfg_scale=cfg_scale,
|
95 |
+
width=width,
|
96 |
+
height=height,
|
97 |
+
restore_faces=restore_faces,
|
98 |
+
tiling=tiling,
|
99 |
+
init_images=[image],
|
100 |
+
mask=mask,
|
101 |
+
mask_blur=mask_blur,
|
102 |
+
inpainting_fill=inpainting_fill,
|
103 |
+
resize_mode=resize_mode,
|
104 |
+
denoising_strength=denoising_strength,
|
105 |
+
inpaint_full_res=inpaint_full_res,
|
106 |
+
inpaint_full_res_padding=inpaint_full_res_padding,
|
107 |
+
inpainting_mask_invert=inpainting_mask_invert,
|
108 |
+
)
|
109 |
+
|
110 |
+
if shared.cmd_opts.enable_console_prompts:
|
111 |
+
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
112 |
+
|
113 |
+
p.extra_generation_params["Mask blur"] = mask_blur
|
114 |
+
|
115 |
+
if is_batch:
|
116 |
+
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
117 |
+
|
118 |
+
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, args)
|
119 |
+
|
120 |
+
processed = Processed(p, [], p.seed, "")
|
121 |
+
else:
|
122 |
+
processed = modules.scripts.scripts_img2img.run(p, *args)
|
123 |
+
if processed is None:
|
124 |
+
processed = process_images(p)
|
125 |
+
|
126 |
+
shared.total_tqdm.clear()
|
127 |
+
|
128 |
+
generation_info_js = processed.js()
|
129 |
+
if opts.samples_log_stdout:
|
130 |
+
print(generation_info_js)
|
131 |
+
|
132 |
+
if opts.do_not_show_images:
|
133 |
+
processed.images = []
|
134 |
+
|
135 |
+
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
modules/interrogate.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import traceback
|
5 |
+
from collections import namedtuple
|
6 |
+
import re
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from torchvision import transforms
|
11 |
+
from torchvision.transforms.functional import InterpolationMode
|
12 |
+
|
13 |
+
import modules.shared as shared
|
14 |
+
from modules import devices, paths, lowvram
|
15 |
+
|
16 |
+
blip_image_eval_size = 384
|
17 |
+
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
18 |
+
clip_model_name = 'ViT-L/14'
|
19 |
+
|
20 |
+
Category = namedtuple("Category", ["name", "topn", "items"])
|
21 |
+
|
22 |
+
re_topn = re.compile(r"\.top(\d+)\.")
|
23 |
+
|
24 |
+
|
25 |
+
class InterrogateModels:
|
26 |
+
blip_model = None
|
27 |
+
clip_model = None
|
28 |
+
clip_preprocess = None
|
29 |
+
categories = None
|
30 |
+
dtype = None
|
31 |
+
|
32 |
+
def __init__(self, content_dir):
|
33 |
+
self.categories = []
|
34 |
+
|
35 |
+
if os.path.exists(content_dir):
|
36 |
+
for filename in os.listdir(content_dir):
|
37 |
+
m = re_topn.search(filename)
|
38 |
+
topn = 1 if m is None else int(m.group(1))
|
39 |
+
|
40 |
+
with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
|
41 |
+
lines = [x.strip() for x in file.readlines()]
|
42 |
+
|
43 |
+
self.categories.append(Category(name=filename, topn=topn, items=lines))
|
44 |
+
|
45 |
+
def load_blip_model(self):
|
46 |
+
import models.blip
|
47 |
+
|
48 |
+
blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
|
49 |
+
blip_model.eval()
|
50 |
+
|
51 |
+
return blip_model
|
52 |
+
|
53 |
+
def load_clip_model(self):
|
54 |
+
import clip
|
55 |
+
|
56 |
+
model, preprocess = clip.load(clip_model_name)
|
57 |
+
model.eval()
|
58 |
+
model = model.to(shared.device)
|
59 |
+
|
60 |
+
return model, preprocess
|
61 |
+
|
62 |
+
def load(self):
|
63 |
+
if self.blip_model is None:
|
64 |
+
self.blip_model = self.load_blip_model()
|
65 |
+
if not shared.cmd_opts.no_half:
|
66 |
+
self.blip_model = self.blip_model.half()
|
67 |
+
|
68 |
+
self.blip_model = self.blip_model.to(shared.device)
|
69 |
+
|
70 |
+
if self.clip_model is None:
|
71 |
+
self.clip_model, self.clip_preprocess = self.load_clip_model()
|
72 |
+
if not shared.cmd_opts.no_half:
|
73 |
+
self.clip_model = self.clip_model.half()
|
74 |
+
|
75 |
+
self.clip_model = self.clip_model.to(shared.device)
|
76 |
+
|
77 |
+
self.dtype = next(self.clip_model.parameters()).dtype
|
78 |
+
|
79 |
+
def send_clip_to_ram(self):
|
80 |
+
if not shared.opts.interrogate_keep_models_in_memory:
|
81 |
+
if self.clip_model is not None:
|
82 |
+
self.clip_model = self.clip_model.to(devices.cpu)
|
83 |
+
|
84 |
+
def send_blip_to_ram(self):
|
85 |
+
if not shared.opts.interrogate_keep_models_in_memory:
|
86 |
+
if self.blip_model is not None:
|
87 |
+
self.blip_model = self.blip_model.to(devices.cpu)
|
88 |
+
|
89 |
+
def unload(self):
|
90 |
+
self.send_clip_to_ram()
|
91 |
+
self.send_blip_to_ram()
|
92 |
+
|
93 |
+
devices.torch_gc()
|
94 |
+
|
95 |
+
def rank(self, image_features, text_array, top_count=1):
|
96 |
+
import clip
|
97 |
+
|
98 |
+
if shared.opts.interrogate_clip_dict_limit != 0:
|
99 |
+
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
100 |
+
|
101 |
+
top_count = min(top_count, len(text_array))
|
102 |
+
text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device)
|
103 |
+
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
|
104 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
105 |
+
|
106 |
+
similarity = torch.zeros((1, len(text_array))).to(shared.device)
|
107 |
+
for i in range(image_features.shape[0]):
|
108 |
+
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
|
109 |
+
similarity /= image_features.shape[0]
|
110 |
+
|
111 |
+
top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
|
112 |
+
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
|
113 |
+
|
114 |
+
def generate_caption(self, pil_image):
|
115 |
+
gpu_image = transforms.Compose([
|
116 |
+
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
|
117 |
+
transforms.ToTensor(),
|
118 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
119 |
+
])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
|
120 |
+
|
121 |
+
with torch.no_grad():
|
122 |
+
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
|
123 |
+
|
124 |
+
return caption[0]
|
125 |
+
|
126 |
+
def interrogate(self, pil_image):
|
127 |
+
res = None
|
128 |
+
|
129 |
+
try:
|
130 |
+
|
131 |
+
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
132 |
+
lowvram.send_everything_to_cpu()
|
133 |
+
devices.torch_gc()
|
134 |
+
|
135 |
+
self.load()
|
136 |
+
|
137 |
+
caption = self.generate_caption(pil_image)
|
138 |
+
self.send_blip_to_ram()
|
139 |
+
devices.torch_gc()
|
140 |
+
|
141 |
+
res = caption
|
142 |
+
|
143 |
+
cilp_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
|
144 |
+
|
145 |
+
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
|
146 |
+
with torch.no_grad(), precision_scope("cuda"):
|
147 |
+
image_features = self.clip_model.encode_image(cilp_image).type(self.dtype)
|
148 |
+
|
149 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
150 |
+
|
151 |
+
if shared.opts.interrogate_use_builtin_artists:
|
152 |
+
artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
|
153 |
+
|
154 |
+
res += ", " + artist[0]
|
155 |
+
|
156 |
+
for name, topn, items in self.categories:
|
157 |
+
matches = self.rank(image_features, items, top_count=topn)
|
158 |
+
for match, score in matches:
|
159 |
+
res += ", " + match
|
160 |
+
|
161 |
+
except Exception:
|
162 |
+
print(f"Error interrogating", file=sys.stderr)
|
163 |
+
print(traceback.format_exc(), file=sys.stderr)
|
164 |
+
res += "<error>"
|
165 |
+
|
166 |
+
self.unload()
|
167 |
+
|
168 |
+
return res
|
modules/ldsr_model.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
from basicsr.utils.download_util import load_file_from_url
|
6 |
+
|
7 |
+
from modules.upscaler import Upscaler, UpscalerData
|
8 |
+
from modules.ldsr_model_arch import LDSR
|
9 |
+
from modules import shared
|
10 |
+
from modules.paths import models_path
|
11 |
+
|
12 |
+
|
13 |
+
class UpscalerLDSR(Upscaler):
|
14 |
+
def __init__(self, user_path):
|
15 |
+
self.name = "LDSR"
|
16 |
+
self.model_path = os.path.join(models_path, self.name)
|
17 |
+
self.user_path = user_path
|
18 |
+
self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
19 |
+
self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
20 |
+
super().__init__()
|
21 |
+
scaler_data = UpscalerData("LDSR", None, self)
|
22 |
+
self.scalers = [scaler_data]
|
23 |
+
|
24 |
+
def load_model(self, path: str):
|
25 |
+
# Remove incorrect project.yaml file if too big
|
26 |
+
yaml_path = os.path.join(self.model_path, "project.yaml")
|
27 |
+
old_model_path = os.path.join(self.model_path, "model.pth")
|
28 |
+
new_model_path = os.path.join(self.model_path, "model.ckpt")
|
29 |
+
if os.path.exists(yaml_path):
|
30 |
+
statinfo = os.stat(yaml_path)
|
31 |
+
if statinfo.st_size >= 10485760:
|
32 |
+
print("Removing invalid LDSR YAML file.")
|
33 |
+
os.remove(yaml_path)
|
34 |
+
if os.path.exists(old_model_path):
|
35 |
+
print("Renaming model from model.pth to model.ckpt")
|
36 |
+
os.rename(old_model_path, new_model_path)
|
37 |
+
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
38 |
+
file_name="model.ckpt", progress=True)
|
39 |
+
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
|
40 |
+
file_name="project.yaml", progress=True)
|
41 |
+
|
42 |
+
try:
|
43 |
+
return LDSR(model, yaml)
|
44 |
+
|
45 |
+
except Exception:
|
46 |
+
print("Error importing LDSR:", file=sys.stderr)
|
47 |
+
print(traceback.format_exc(), file=sys.stderr)
|
48 |
+
return None
|
49 |
+
|
50 |
+
def do_upscale(self, img, path):
|
51 |
+
ldsr = self.load_model(path)
|
52 |
+
if ldsr is None:
|
53 |
+
print("NO LDSR!")
|
54 |
+
return img
|
55 |
+
ddim_steps = shared.opts.ldsr_steps
|
56 |
+
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
modules/ldsr_model_arch.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import time
|
3 |
+
import warnings
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torchvision
|
8 |
+
from PIL import Image
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
|
12 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
13 |
+
from ldm.util import instantiate_from_config, ismap
|
14 |
+
|
15 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
16 |
+
|
17 |
+
|
18 |
+
# Create LDSR Class
|
19 |
+
class LDSR:
|
20 |
+
def load_model_from_config(self, half_attention):
|
21 |
+
print(f"Loading model from {self.modelPath}")
|
22 |
+
pl_sd = torch.load(self.modelPath, map_location="cpu")
|
23 |
+
sd = pl_sd["state_dict"]
|
24 |
+
config = OmegaConf.load(self.yamlPath)
|
25 |
+
model = instantiate_from_config(config.model)
|
26 |
+
model.load_state_dict(sd, strict=False)
|
27 |
+
model.cuda()
|
28 |
+
if half_attention:
|
29 |
+
model = model.half()
|
30 |
+
|
31 |
+
model.eval()
|
32 |
+
return {"model": model}
|
33 |
+
|
34 |
+
def __init__(self, model_path, yaml_path):
|
35 |
+
self.modelPath = model_path
|
36 |
+
self.yamlPath = yaml_path
|
37 |
+
|
38 |
+
@staticmethod
|
39 |
+
def run(model, selected_path, custom_steps, eta):
|
40 |
+
example = get_cond(selected_path)
|
41 |
+
|
42 |
+
n_runs = 1
|
43 |
+
guider = None
|
44 |
+
ckwargs = None
|
45 |
+
ddim_use_x0_pred = False
|
46 |
+
temperature = 1.
|
47 |
+
eta = eta
|
48 |
+
custom_shape = None
|
49 |
+
|
50 |
+
height, width = example["image"].shape[1:3]
|
51 |
+
split_input = height >= 128 and width >= 128
|
52 |
+
|
53 |
+
if split_input:
|
54 |
+
ks = 128
|
55 |
+
stride = 64
|
56 |
+
vqf = 4 #
|
57 |
+
model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
|
58 |
+
"vqf": vqf,
|
59 |
+
"patch_distributed_vq": True,
|
60 |
+
"tie_braker": False,
|
61 |
+
"clip_max_weight": 0.5,
|
62 |
+
"clip_min_weight": 0.01,
|
63 |
+
"clip_max_tie_weight": 0.5,
|
64 |
+
"clip_min_tie_weight": 0.01}
|
65 |
+
else:
|
66 |
+
if hasattr(model, "split_input_params"):
|
67 |
+
delattr(model, "split_input_params")
|
68 |
+
|
69 |
+
x_t = None
|
70 |
+
logs = None
|
71 |
+
for n in range(n_runs):
|
72 |
+
if custom_shape is not None:
|
73 |
+
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
|
74 |
+
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
|
75 |
+
|
76 |
+
logs = make_convolutional_sample(example, model,
|
77 |
+
custom_steps=custom_steps,
|
78 |
+
eta=eta, quantize_x0=False,
|
79 |
+
custom_shape=custom_shape,
|
80 |
+
temperature=temperature, noise_dropout=0.,
|
81 |
+
corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
|
82 |
+
ddim_use_x0_pred=ddim_use_x0_pred
|
83 |
+
)
|
84 |
+
return logs
|
85 |
+
|
86 |
+
def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
|
87 |
+
model = self.load_model_from_config(half_attention)
|
88 |
+
|
89 |
+
# Run settings
|
90 |
+
diffusion_steps = int(steps)
|
91 |
+
eta = 1.0
|
92 |
+
|
93 |
+
down_sample_method = 'Lanczos'
|
94 |
+
|
95 |
+
gc.collect()
|
96 |
+
torch.cuda.empty_cache()
|
97 |
+
|
98 |
+
im_og = image
|
99 |
+
width_og, height_og = im_og.size
|
100 |
+
# If we can adjust the max upscale size, then the 4 below should be our variable
|
101 |
+
down_sample_rate = target_scale / 4
|
102 |
+
wd = width_og * down_sample_rate
|
103 |
+
hd = height_og * down_sample_rate
|
104 |
+
width_downsampled_pre = int(wd)
|
105 |
+
height_downsampled_pre = int(hd)
|
106 |
+
|
107 |
+
if down_sample_rate != 1:
|
108 |
+
print(
|
109 |
+
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
|
110 |
+
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
111 |
+
else:
|
112 |
+
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
|
113 |
+
logs = self.run(model["model"], im_og, diffusion_steps, eta)
|
114 |
+
|
115 |
+
sample = logs["sample"]
|
116 |
+
sample = sample.detach().cpu()
|
117 |
+
sample = torch.clamp(sample, -1., 1.)
|
118 |
+
sample = (sample + 1.) / 2. * 255
|
119 |
+
sample = sample.numpy().astype(np.uint8)
|
120 |
+
sample = np.transpose(sample, (0, 2, 3, 1))
|
121 |
+
a = Image.fromarray(sample[0])
|
122 |
+
|
123 |
+
del model
|
124 |
+
gc.collect()
|
125 |
+
torch.cuda.empty_cache()
|
126 |
+
return a
|
127 |
+
|
128 |
+
|
129 |
+
def get_cond(selected_path):
|
130 |
+
example = dict()
|
131 |
+
up_f = 4
|
132 |
+
c = selected_path.convert('RGB')
|
133 |
+
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
|
134 |
+
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
|
135 |
+
antialias=True)
|
136 |
+
c_up = rearrange(c_up, '1 c h w -> 1 h w c')
|
137 |
+
c = rearrange(c, '1 c h w -> 1 h w c')
|
138 |
+
c = 2. * c - 1.
|
139 |
+
|
140 |
+
c = c.to(torch.device("cuda"))
|
141 |
+
example["LR_image"] = c
|
142 |
+
example["image"] = c_up
|
143 |
+
|
144 |
+
return example
|
145 |
+
|
146 |
+
|
147 |
+
@torch.no_grad()
|
148 |
+
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
|
149 |
+
mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
|
150 |
+
corrector_kwargs=None, x_t=None
|
151 |
+
):
|
152 |
+
ddim = DDIMSampler(model)
|
153 |
+
bs = shape[0]
|
154 |
+
shape = shape[1:]
|
155 |
+
print(f"Sampling with eta = {eta}; steps: {steps}")
|
156 |
+
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
|
157 |
+
normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
|
158 |
+
mask=mask, x0=x0, temperature=temperature, verbose=False,
|
159 |
+
score_corrector=score_corrector,
|
160 |
+
corrector_kwargs=corrector_kwargs, x_t=x_t)
|
161 |
+
|
162 |
+
return samples, intermediates
|
163 |
+
|
164 |
+
|
165 |
+
@torch.no_grad()
|
166 |
+
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
|
167 |
+
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
|
168 |
+
log = dict()
|
169 |
+
|
170 |
+
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
|
171 |
+
return_first_stage_outputs=True,
|
172 |
+
force_c_encode=not (hasattr(model, 'split_input_params')
|
173 |
+
and model.cond_stage_key == 'coordinates_bbox'),
|
174 |
+
return_original_cond=True)
|
175 |
+
|
176 |
+
if custom_shape is not None:
|
177 |
+
z = torch.randn(custom_shape)
|
178 |
+
print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
|
179 |
+
|
180 |
+
z0 = None
|
181 |
+
|
182 |
+
log["input"] = x
|
183 |
+
log["reconstruction"] = xrec
|
184 |
+
|
185 |
+
if ismap(xc):
|
186 |
+
log["original_conditioning"] = model.to_rgb(xc)
|
187 |
+
if hasattr(model, 'cond_stage_key'):
|
188 |
+
log[model.cond_stage_key] = model.to_rgb(xc)
|
189 |
+
|
190 |
+
else:
|
191 |
+
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
|
192 |
+
if model.cond_stage_model:
|
193 |
+
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
|
194 |
+
if model.cond_stage_key == 'class_label':
|
195 |
+
log[model.cond_stage_key] = xc[model.cond_stage_key]
|
196 |
+
|
197 |
+
with model.ema_scope("Plotting"):
|
198 |
+
t0 = time.time()
|
199 |
+
|
200 |
+
sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
|
201 |
+
eta=eta,
|
202 |
+
quantize_x0=quantize_x0, mask=None, x0=z0,
|
203 |
+
temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
|
204 |
+
x_t=x_T)
|
205 |
+
t1 = time.time()
|
206 |
+
|
207 |
+
if ddim_use_x0_pred:
|
208 |
+
sample = intermediates['pred_x0'][-1]
|
209 |
+
|
210 |
+
x_sample = model.decode_first_stage(sample)
|
211 |
+
|
212 |
+
try:
|
213 |
+
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
214 |
+
log["sample_noquant"] = x_sample_noquant
|
215 |
+
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
216 |
+
except:
|
217 |
+
pass
|
218 |
+
|
219 |
+
log["sample"] = x_sample
|
220 |
+
log["time"] = t1 - t0
|
221 |
+
|
222 |
+
return log
|
modules/lowvram.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from modules.devices import get_optimal_device
|
3 |
+
|
4 |
+
module_in_gpu = None
|
5 |
+
cpu = torch.device("cpu")
|
6 |
+
device = gpu = get_optimal_device()
|
7 |
+
|
8 |
+
|
9 |
+
def send_everything_to_cpu():
|
10 |
+
global module_in_gpu
|
11 |
+
|
12 |
+
if module_in_gpu is not None:
|
13 |
+
module_in_gpu.to(cpu)
|
14 |
+
|
15 |
+
module_in_gpu = None
|
16 |
+
|
17 |
+
|
18 |
+
def setup_for_low_vram(sd_model, use_medvram):
|
19 |
+
parents = {}
|
20 |
+
|
21 |
+
def send_me_to_gpu(module, _):
|
22 |
+
"""send this module to GPU; send whatever tracked module was previous in GPU to CPU;
|
23 |
+
we add this as forward_pre_hook to a lot of modules and this way all but one of them will
|
24 |
+
be in CPU
|
25 |
+
"""
|
26 |
+
global module_in_gpu
|
27 |
+
|
28 |
+
module = parents.get(module, module)
|
29 |
+
|
30 |
+
if module_in_gpu == module:
|
31 |
+
return
|
32 |
+
|
33 |
+
if module_in_gpu is not None:
|
34 |
+
module_in_gpu.to(cpu)
|
35 |
+
|
36 |
+
module.to(gpu)
|
37 |
+
module_in_gpu = module
|
38 |
+
|
39 |
+
# see below for register_forward_pre_hook;
|
40 |
+
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
|
41 |
+
# useless here, and we just replace those methods
|
42 |
+
def first_stage_model_encode_wrap(self, encoder, x):
|
43 |
+
send_me_to_gpu(self, None)
|
44 |
+
return encoder(x)
|
45 |
+
|
46 |
+
def first_stage_model_decode_wrap(self, decoder, z):
|
47 |
+
send_me_to_gpu(self, None)
|
48 |
+
return decoder(z)
|
49 |
+
|
50 |
+
# remove three big modules, cond, first_stage, and unet from the model and then
|
51 |
+
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
52 |
+
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
|
53 |
+
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
|
54 |
+
sd_model.to(device)
|
55 |
+
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
|
56 |
+
|
57 |
+
# register hooks for those the first two models
|
58 |
+
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
59 |
+
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
60 |
+
sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
|
61 |
+
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
|
62 |
+
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
63 |
+
|
64 |
+
if use_medvram:
|
65 |
+
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
|
66 |
+
else:
|
67 |
+
diff_model = sd_model.model.diffusion_model
|
68 |
+
|
69 |
+
# the third remaining model is still too big for 4 GB, so we also do the same for its submodules
|
70 |
+
# so that only one of them is in GPU at a time
|
71 |
+
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
|
72 |
+
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
|
73 |
+
sd_model.model.to(device)
|
74 |
+
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
|
75 |
+
|
76 |
+
# install hooks for bits of third model
|
77 |
+
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
|
78 |
+
for block in diff_model.input_blocks:
|
79 |
+
block.register_forward_pre_hook(send_me_to_gpu)
|
80 |
+
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
|
81 |
+
for block in diff_model.output_blocks:
|
82 |
+
block.register_forward_pre_hook(send_me_to_gpu)
|
modules/masking.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image, ImageFilter, ImageOps
|
2 |
+
|
3 |
+
|
4 |
+
def get_crop_region(mask, pad=0):
|
5 |
+
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
|
6 |
+
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
|
7 |
+
|
8 |
+
h, w = mask.shape
|
9 |
+
|
10 |
+
crop_left = 0
|
11 |
+
for i in range(w):
|
12 |
+
if not (mask[:, i] == 0).all():
|
13 |
+
break
|
14 |
+
crop_left += 1
|
15 |
+
|
16 |
+
crop_right = 0
|
17 |
+
for i in reversed(range(w)):
|
18 |
+
if not (mask[:, i] == 0).all():
|
19 |
+
break
|
20 |
+
crop_right += 1
|
21 |
+
|
22 |
+
crop_top = 0
|
23 |
+
for i in range(h):
|
24 |
+
if not (mask[i] == 0).all():
|
25 |
+
break
|
26 |
+
crop_top += 1
|
27 |
+
|
28 |
+
crop_bottom = 0
|
29 |
+
for i in reversed(range(h)):
|
30 |
+
if not (mask[i] == 0).all():
|
31 |
+
break
|
32 |
+
crop_bottom += 1
|
33 |
+
|
34 |
+
return (
|
35 |
+
int(max(crop_left-pad, 0)),
|
36 |
+
int(max(crop_top-pad, 0)),
|
37 |
+
int(min(w - crop_right + pad, w)),
|
38 |
+
int(min(h - crop_bottom + pad, h))
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
|
43 |
+
"""expands crop region get_crop_region() to match the ratio of the image the region will processed in; returns expanded region
|
44 |
+
for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
|
45 |
+
|
46 |
+
x1, y1, x2, y2 = crop_region
|
47 |
+
|
48 |
+
ratio_crop_region = (x2 - x1) / (y2 - y1)
|
49 |
+
ratio_processing = processing_width / processing_height
|
50 |
+
|
51 |
+
if ratio_crop_region > ratio_processing:
|
52 |
+
desired_height = (x2 - x1) * ratio_processing
|
53 |
+
desired_height_diff = int(desired_height - (y2-y1))
|
54 |
+
y1 -= desired_height_diff//2
|
55 |
+
y2 += desired_height_diff - desired_height_diff//2
|
56 |
+
if y2 >= image_height:
|
57 |
+
diff = y2 - image_height
|
58 |
+
y2 -= diff
|
59 |
+
y1 -= diff
|
60 |
+
if y1 < 0:
|
61 |
+
y2 -= y1
|
62 |
+
y1 -= y1
|
63 |
+
if y2 >= image_height:
|
64 |
+
y2 = image_height
|
65 |
+
else:
|
66 |
+
desired_width = (y2 - y1) * ratio_processing
|
67 |
+
desired_width_diff = int(desired_width - (x2-x1))
|
68 |
+
x1 -= desired_width_diff//2
|
69 |
+
x2 += desired_width_diff - desired_width_diff//2
|
70 |
+
if x2 >= image_width:
|
71 |
+
diff = x2 - image_width
|
72 |
+
x2 -= diff
|
73 |
+
x1 -= diff
|
74 |
+
if x1 < 0:
|
75 |
+
x2 -= x1
|
76 |
+
x1 -= x1
|
77 |
+
if x2 >= image_width:
|
78 |
+
x2 = image_width
|
79 |
+
|
80 |
+
return x1, y1, x2, y2
|
81 |
+
|
82 |
+
|
83 |
+
def fill(image, mask):
|
84 |
+
"""fills masked regions with colors from image using blur. Not extremely effective."""
|
85 |
+
|
86 |
+
image_mod = Image.new('RGBA', (image.width, image.height))
|
87 |
+
|
88 |
+
image_masked = Image.new('RGBa', (image.width, image.height))
|
89 |
+
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
|
90 |
+
|
91 |
+
image_masked = image_masked.convert('RGBa')
|
92 |
+
|
93 |
+
for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
|
94 |
+
blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
|
95 |
+
for _ in range(repeats):
|
96 |
+
image_mod.alpha_composite(blurred)
|
97 |
+
|
98 |
+
return image_mod.convert("RGB")
|
99 |
+
|
modules/memmon.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import threading
|
2 |
+
import time
|
3 |
+
from collections import defaultdict
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
class MemUsageMonitor(threading.Thread):
|
9 |
+
run_flag = None
|
10 |
+
device = None
|
11 |
+
disabled = False
|
12 |
+
opts = None
|
13 |
+
data = None
|
14 |
+
|
15 |
+
def __init__(self, name, device, opts):
|
16 |
+
threading.Thread.__init__(self)
|
17 |
+
self.name = name
|
18 |
+
self.device = device
|
19 |
+
self.opts = opts
|
20 |
+
|
21 |
+
self.daemon = True
|
22 |
+
self.run_flag = threading.Event()
|
23 |
+
self.data = defaultdict(int)
|
24 |
+
|
25 |
+
try:
|
26 |
+
torch.cuda.mem_get_info()
|
27 |
+
torch.cuda.memory_stats(self.device)
|
28 |
+
except Exception as e: # AMD or whatever
|
29 |
+
print(f"Warning: caught exception '{e}', memory monitor disabled")
|
30 |
+
self.disabled = True
|
31 |
+
|
32 |
+
def run(self):
|
33 |
+
if self.disabled:
|
34 |
+
return
|
35 |
+
|
36 |
+
while True:
|
37 |
+
self.run_flag.wait()
|
38 |
+
|
39 |
+
torch.cuda.reset_peak_memory_stats()
|
40 |
+
self.data.clear()
|
41 |
+
|
42 |
+
if self.opts.memmon_poll_rate <= 0:
|
43 |
+
self.run_flag.clear()
|
44 |
+
continue
|
45 |
+
|
46 |
+
self.data["min_free"] = torch.cuda.mem_get_info()[0]
|
47 |
+
|
48 |
+
while self.run_flag.is_set():
|
49 |
+
free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug?
|
50 |
+
self.data["min_free"] = min(self.data["min_free"], free)
|
51 |
+
|
52 |
+
time.sleep(1 / self.opts.memmon_poll_rate)
|
53 |
+
|
54 |
+
def dump_debug(self):
|
55 |
+
print(self, 'recorded data:')
|
56 |
+
for k, v in self.read().items():
|
57 |
+
print(k, -(v // -(1024 ** 2)))
|
58 |
+
|
59 |
+
print(self, 'raw torch memory stats:')
|
60 |
+
tm = torch.cuda.memory_stats(self.device)
|
61 |
+
for k, v in tm.items():
|
62 |
+
if 'bytes' not in k:
|
63 |
+
continue
|
64 |
+
print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
|
65 |
+
|
66 |
+
print(torch.cuda.memory_summary())
|
67 |
+
|
68 |
+
def monitor(self):
|
69 |
+
self.run_flag.set()
|
70 |
+
|
71 |
+
def read(self):
|
72 |
+
if not self.disabled:
|
73 |
+
free, total = torch.cuda.mem_get_info()
|
74 |
+
self.data["total"] = total
|
75 |
+
|
76 |
+
torch_stats = torch.cuda.memory_stats(self.device)
|
77 |
+
self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
|
78 |
+
self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
|
79 |
+
self.data["system_peak"] = total - self.data["min_free"]
|
80 |
+
|
81 |
+
return self.data
|
82 |
+
|
83 |
+
def stop(self):
|
84 |
+
self.run_flag.clear()
|
85 |
+
return self.read()
|
modules/modelloader.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
import importlib
|
5 |
+
from urllib.parse import urlparse
|
6 |
+
|
7 |
+
from basicsr.utils.download_util import load_file_from_url
|
8 |
+
from modules import shared
|
9 |
+
from modules.upscaler import Upscaler
|
10 |
+
from modules.paths import script_path, models_path
|
11 |
+
|
12 |
+
|
13 |
+
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
|
14 |
+
"""
|
15 |
+
A one-and done loader to try finding the desired models in specified directories.
|
16 |
+
|
17 |
+
@param download_name: Specify to download from model_url immediately.
|
18 |
+
@param model_url: If no other models are found, this will be downloaded on upscale.
|
19 |
+
@param model_path: The location to store/find models in.
|
20 |
+
@param command_path: A command-line argument to search for models in first.
|
21 |
+
@param ext_filter: An optional list of filename extensions to filter by
|
22 |
+
@return: A list of paths containing the desired model(s)
|
23 |
+
"""
|
24 |
+
output = []
|
25 |
+
|
26 |
+
if ext_filter is None:
|
27 |
+
ext_filter = []
|
28 |
+
|
29 |
+
try:
|
30 |
+
places = []
|
31 |
+
|
32 |
+
if command_path is not None and command_path != model_path:
|
33 |
+
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
|
34 |
+
if os.path.exists(pretrained_path):
|
35 |
+
print(f"Appending path: {pretrained_path}")
|
36 |
+
places.append(pretrained_path)
|
37 |
+
elif os.path.exists(command_path):
|
38 |
+
places.append(command_path)
|
39 |
+
|
40 |
+
places.append(model_path)
|
41 |
+
|
42 |
+
for place in places:
|
43 |
+
if os.path.exists(place):
|
44 |
+
for file in glob.iglob(place + '**/**', recursive=True):
|
45 |
+
full_path = file
|
46 |
+
if os.path.isdir(full_path):
|
47 |
+
continue
|
48 |
+
if len(ext_filter) != 0:
|
49 |
+
model_name, extension = os.path.splitext(file)
|
50 |
+
if extension not in ext_filter:
|
51 |
+
continue
|
52 |
+
if file not in output:
|
53 |
+
output.append(full_path)
|
54 |
+
|
55 |
+
if model_url is not None and len(output) == 0:
|
56 |
+
if download_name is not None:
|
57 |
+
dl = load_file_from_url(model_url, model_path, True, download_name)
|
58 |
+
output.append(dl)
|
59 |
+
else:
|
60 |
+
output.append(model_url)
|
61 |
+
|
62 |
+
except Exception:
|
63 |
+
pass
|
64 |
+
|
65 |
+
return output
|
66 |
+
|
67 |
+
|
68 |
+
def friendly_name(file: str):
|
69 |
+
if "http" in file:
|
70 |
+
file = urlparse(file).path
|
71 |
+
|
72 |
+
file = os.path.basename(file)
|
73 |
+
model_name, extension = os.path.splitext(file)
|
74 |
+
return model_name
|
75 |
+
|
76 |
+
|
77 |
+
def cleanup_models():
|
78 |
+
# This code could probably be more efficient if we used a tuple list or something to store the src/destinations
|
79 |
+
# and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
|
80 |
+
# somehow auto-register and just do these things...
|
81 |
+
root_path = script_path
|
82 |
+
src_path = models_path
|
83 |
+
dest_path = os.path.join(models_path, "Stable-diffusion")
|
84 |
+
move_files(src_path, dest_path, ".ckpt")
|
85 |
+
src_path = os.path.join(root_path, "ESRGAN")
|
86 |
+
dest_path = os.path.join(models_path, "ESRGAN")
|
87 |
+
move_files(src_path, dest_path)
|
88 |
+
src_path = os.path.join(root_path, "gfpgan")
|
89 |
+
dest_path = os.path.join(models_path, "GFPGAN")
|
90 |
+
move_files(src_path, dest_path)
|
91 |
+
src_path = os.path.join(root_path, "SwinIR")
|
92 |
+
dest_path = os.path.join(models_path, "SwinIR")
|
93 |
+
move_files(src_path, dest_path)
|
94 |
+
src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
|
95 |
+
dest_path = os.path.join(models_path, "LDSR")
|
96 |
+
move_files(src_path, dest_path)
|
97 |
+
|
98 |
+
|
99 |
+
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
100 |
+
try:
|
101 |
+
if not os.path.exists(dest_path):
|
102 |
+
os.makedirs(dest_path)
|
103 |
+
if os.path.exists(src_path):
|
104 |
+
for file in os.listdir(src_path):
|
105 |
+
fullpath = os.path.join(src_path, file)
|
106 |
+
if os.path.isfile(fullpath):
|
107 |
+
if ext_filter is not None:
|
108 |
+
if ext_filter not in file:
|
109 |
+
continue
|
110 |
+
print(f"Moving {file} from {src_path} to {dest_path}.")
|
111 |
+
try:
|
112 |
+
shutil.move(fullpath, dest_path)
|
113 |
+
except:
|
114 |
+
pass
|
115 |
+
if len(os.listdir(src_path)) == 0:
|
116 |
+
print(f"Removing empty folder: {src_path}")
|
117 |
+
shutil.rmtree(src_path, True)
|
118 |
+
except:
|
119 |
+
pass
|
120 |
+
|
121 |
+
|
122 |
+
def load_upscalers():
|
123 |
+
sd = shared.script_path
|
124 |
+
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
|
125 |
+
# so we'll try to import any _model.py files before looking in __subclasses__
|
126 |
+
modules_dir = os.path.join(sd, "modules")
|
127 |
+
for file in os.listdir(modules_dir):
|
128 |
+
if "_model.py" in file:
|
129 |
+
model_name = file.replace("_model.py", "")
|
130 |
+
full_model = f"modules.{model_name}_model"
|
131 |
+
try:
|
132 |
+
importlib.import_module(full_model)
|
133 |
+
except:
|
134 |
+
pass
|
135 |
+
datas = []
|
136 |
+
c_o = vars(shared.cmd_opts)
|
137 |
+
for cls in Upscaler.__subclasses__():
|
138 |
+
name = cls.__name__
|
139 |
+
module_name = cls.__module__
|
140 |
+
module = importlib.import_module(module_name)
|
141 |
+
class_ = getattr(module, name)
|
142 |
+
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
143 |
+
opt_string = None
|
144 |
+
try:
|
145 |
+
if cmd_name in c_o:
|
146 |
+
opt_string = c_o[cmd_name]
|
147 |
+
except:
|
148 |
+
pass
|
149 |
+
scaler = class_(opt_string)
|
150 |
+
for child in scaler.scalers:
|
151 |
+
datas.append(child)
|
152 |
+
|
153 |
+
shared.sd_upscalers = datas
|
modules/paths.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
|
5 |
+
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
6 |
+
models_path = os.path.join(script_path, "models")
|
7 |
+
sys.path.insert(0, script_path)
|
8 |
+
|
9 |
+
# search for directory of stable diffusion in following places
|
10 |
+
sd_path = None
|
11 |
+
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
|
12 |
+
for possible_sd_path in possible_sd_paths:
|
13 |
+
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
|
14 |
+
sd_path = os.path.abspath(possible_sd_path)
|
15 |
+
|
16 |
+
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
|
17 |
+
|
18 |
+
path_dirs = [
|
19 |
+
(sd_path, 'ldm', 'Stable Diffusion', []),
|
20 |
+
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
|
21 |
+
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
|
22 |
+
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
23 |
+
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
24 |
+
]
|
25 |
+
|
26 |
+
paths = {}
|
27 |
+
|
28 |
+
for d, must_exist, what, options in path_dirs:
|
29 |
+
must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
|
30 |
+
if not os.path.exists(must_exist_path):
|
31 |
+
print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
|
32 |
+
else:
|
33 |
+
d = os.path.abspath(d)
|
34 |
+
if "atstart" in options:
|
35 |
+
sys.path.insert(0, d)
|
36 |
+
else:
|
37 |
+
sys.path.append(d)
|
38 |
+
paths[what] = d
|
modules/processing.py
ADDED
@@ -0,0 +1,688 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image, ImageFilter, ImageOps
|
9 |
+
import random
|
10 |
+
import cv2
|
11 |
+
from skimage import exposure
|
12 |
+
|
13 |
+
import modules.sd_hijack
|
14 |
+
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
|
15 |
+
from modules.sd_hijack import model_hijack
|
16 |
+
from modules.shared import opts, cmd_opts, state
|
17 |
+
import modules.shared as shared
|
18 |
+
import modules.face_restoration
|
19 |
+
import modules.images as images
|
20 |
+
import modules.styles
|
21 |
+
import logging
|
22 |
+
|
23 |
+
|
24 |
+
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
25 |
+
opt_C = 4
|
26 |
+
opt_f = 8
|
27 |
+
|
28 |
+
|
29 |
+
def setup_color_correction(image):
|
30 |
+
logging.info("Calibrating color correction.")
|
31 |
+
correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
|
32 |
+
return correction_target
|
33 |
+
|
34 |
+
|
35 |
+
def apply_color_correction(correction, image):
|
36 |
+
logging.info("Applying color correction.")
|
37 |
+
image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
|
38 |
+
cv2.cvtColor(
|
39 |
+
np.asarray(image),
|
40 |
+
cv2.COLOR_RGB2LAB
|
41 |
+
),
|
42 |
+
correction,
|
43 |
+
channel_axis=2
|
44 |
+
), cv2.COLOR_LAB2RGB).astype("uint8"))
|
45 |
+
|
46 |
+
return image
|
47 |
+
|
48 |
+
|
49 |
+
class StableDiffusionProcessing:
|
50 |
+
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
|
51 |
+
self.sd_model = sd_model
|
52 |
+
self.outpath_samples: str = outpath_samples
|
53 |
+
self.outpath_grids: str = outpath_grids
|
54 |
+
self.prompt: str = prompt
|
55 |
+
self.prompt_for_display: str = None
|
56 |
+
self.negative_prompt: str = (negative_prompt or "")
|
57 |
+
self.styles: list = styles or []
|
58 |
+
self.seed: int = seed
|
59 |
+
self.subseed: int = subseed
|
60 |
+
self.subseed_strength: float = subseed_strength
|
61 |
+
self.seed_resize_from_h: int = seed_resize_from_h
|
62 |
+
self.seed_resize_from_w: int = seed_resize_from_w
|
63 |
+
self.sampler_index: int = sampler_index
|
64 |
+
self.batch_size: int = batch_size
|
65 |
+
self.n_iter: int = n_iter
|
66 |
+
self.steps: int = steps
|
67 |
+
self.cfg_scale: float = cfg_scale
|
68 |
+
self.width: int = width
|
69 |
+
self.height: int = height
|
70 |
+
self.restore_faces: bool = restore_faces
|
71 |
+
self.tiling: bool = tiling
|
72 |
+
self.do_not_save_samples: bool = do_not_save_samples
|
73 |
+
self.do_not_save_grid: bool = do_not_save_grid
|
74 |
+
self.extra_generation_params: dict = extra_generation_params or {}
|
75 |
+
self.overlay_images = overlay_images
|
76 |
+
self.eta = eta
|
77 |
+
self.paste_to = None
|
78 |
+
self.color_corrections = None
|
79 |
+
self.denoising_strength: float = 0
|
80 |
+
self.sampler_noise_scheduler_override = None
|
81 |
+
self.ddim_discretize = opts.ddim_discretize
|
82 |
+
self.s_churn = opts.s_churn
|
83 |
+
self.s_tmin = opts.s_tmin
|
84 |
+
self.s_tmax = float('inf') # not representable as a standard ui option
|
85 |
+
self.s_noise = opts.s_noise
|
86 |
+
|
87 |
+
if not seed_enable_extras:
|
88 |
+
self.subseed = -1
|
89 |
+
self.subseed_strength = 0
|
90 |
+
self.seed_resize_from_h = 0
|
91 |
+
self.seed_resize_from_w = 0
|
92 |
+
|
93 |
+
def init(self, all_prompts, all_seeds, all_subseeds):
|
94 |
+
pass
|
95 |
+
|
96 |
+
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
97 |
+
raise NotImplementedError()
|
98 |
+
|
99 |
+
|
100 |
+
class Processed:
|
101 |
+
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
|
102 |
+
self.images = images_list
|
103 |
+
self.prompt = p.prompt
|
104 |
+
self.negative_prompt = p.negative_prompt
|
105 |
+
self.seed = seed
|
106 |
+
self.subseed = subseed
|
107 |
+
self.subseed_strength = p.subseed_strength
|
108 |
+
self.info = info
|
109 |
+
self.width = p.width
|
110 |
+
self.height = p.height
|
111 |
+
self.sampler_index = p.sampler_index
|
112 |
+
self.sampler = sd_samplers.samplers[p.sampler_index].name
|
113 |
+
self.cfg_scale = p.cfg_scale
|
114 |
+
self.steps = p.steps
|
115 |
+
self.batch_size = p.batch_size
|
116 |
+
self.restore_faces = p.restore_faces
|
117 |
+
self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
|
118 |
+
self.sd_model_hash = shared.sd_model.sd_model_hash
|
119 |
+
self.seed_resize_from_w = p.seed_resize_from_w
|
120 |
+
self.seed_resize_from_h = p.seed_resize_from_h
|
121 |
+
self.denoising_strength = getattr(p, 'denoising_strength', None)
|
122 |
+
self.extra_generation_params = p.extra_generation_params
|
123 |
+
self.index_of_first_image = index_of_first_image
|
124 |
+
self.styles = p.styles
|
125 |
+
self.job_timestamp = state.job_timestamp
|
126 |
+
|
127 |
+
self.eta = p.eta
|
128 |
+
self.ddim_discretize = p.ddim_discretize
|
129 |
+
self.s_churn = p.s_churn
|
130 |
+
self.s_tmin = p.s_tmin
|
131 |
+
self.s_tmax = p.s_tmax
|
132 |
+
self.s_noise = p.s_noise
|
133 |
+
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
134 |
+
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
135 |
+
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
136 |
+
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
|
137 |
+
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
|
138 |
+
|
139 |
+
self.all_prompts = all_prompts or [self.prompt]
|
140 |
+
self.all_seeds = all_seeds or [self.seed]
|
141 |
+
self.all_subseeds = all_subseeds or [self.subseed]
|
142 |
+
self.infotexts = infotexts or [info]
|
143 |
+
|
144 |
+
def js(self):
|
145 |
+
obj = {
|
146 |
+
"prompt": self.prompt,
|
147 |
+
"all_prompts": self.all_prompts,
|
148 |
+
"negative_prompt": self.negative_prompt,
|
149 |
+
"seed": self.seed,
|
150 |
+
"all_seeds": self.all_seeds,
|
151 |
+
"subseed": self.subseed,
|
152 |
+
"all_subseeds": self.all_subseeds,
|
153 |
+
"subseed_strength": self.subseed_strength,
|
154 |
+
"width": self.width,
|
155 |
+
"height": self.height,
|
156 |
+
"sampler_index": self.sampler_index,
|
157 |
+
"sampler": self.sampler,
|
158 |
+
"cfg_scale": self.cfg_scale,
|
159 |
+
"steps": self.steps,
|
160 |
+
"batch_size": self.batch_size,
|
161 |
+
"restore_faces": self.restore_faces,
|
162 |
+
"face_restoration_model": self.face_restoration_model,
|
163 |
+
"sd_model_hash": self.sd_model_hash,
|
164 |
+
"seed_resize_from_w": self.seed_resize_from_w,
|
165 |
+
"seed_resize_from_h": self.seed_resize_from_h,
|
166 |
+
"denoising_strength": self.denoising_strength,
|
167 |
+
"extra_generation_params": self.extra_generation_params,
|
168 |
+
"index_of_first_image": self.index_of_first_image,
|
169 |
+
"infotexts": self.infotexts,
|
170 |
+
"styles": self.styles,
|
171 |
+
"job_timestamp": self.job_timestamp,
|
172 |
+
}
|
173 |
+
|
174 |
+
return json.dumps(obj)
|
175 |
+
|
176 |
+
def infotext(self, p: StableDiffusionProcessing, index):
|
177 |
+
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
|
178 |
+
|
179 |
+
|
180 |
+
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
|
181 |
+
def slerp(val, low, high):
|
182 |
+
low_norm = low/torch.norm(low, dim=1, keepdim=True)
|
183 |
+
high_norm = high/torch.norm(high, dim=1, keepdim=True)
|
184 |
+
dot = (low_norm*high_norm).sum(1)
|
185 |
+
|
186 |
+
if dot.mean() > 0.9995:
|
187 |
+
return low * val + high * (1 - val)
|
188 |
+
|
189 |
+
omega = torch.acos(dot)
|
190 |
+
so = torch.sin(omega)
|
191 |
+
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
|
192 |
+
return res
|
193 |
+
|
194 |
+
|
195 |
+
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
|
196 |
+
xs = []
|
197 |
+
|
198 |
+
# if we have multiple seeds, this means we are working with batch size>1; this then
|
199 |
+
# enables the generation of additional tensors with noise that the sampler will use during its processing.
|
200 |
+
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
|
201 |
+
# produce the same images as with two batches [100], [101].
|
202 |
+
if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
|
203 |
+
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
|
204 |
+
else:
|
205 |
+
sampler_noises = None
|
206 |
+
|
207 |
+
for i, seed in enumerate(seeds):
|
208 |
+
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
|
209 |
+
|
210 |
+
subnoise = None
|
211 |
+
if subseeds is not None:
|
212 |
+
subseed = 0 if i >= len(subseeds) else subseeds[i]
|
213 |
+
|
214 |
+
subnoise = devices.randn(subseed, noise_shape)
|
215 |
+
|
216 |
+
# randn results depend on device; gpu and cpu get different results for same seed;
|
217 |
+
# the way I see it, it's better to do this on CPU, so that everyone gets same result;
|
218 |
+
# but the original script had it like this, so I do not dare change it for now because
|
219 |
+
# it will break everyone's seeds.
|
220 |
+
noise = devices.randn(seed, noise_shape)
|
221 |
+
|
222 |
+
if subnoise is not None:
|
223 |
+
noise = slerp(subseed_strength, noise, subnoise)
|
224 |
+
|
225 |
+
if noise_shape != shape:
|
226 |
+
x = devices.randn(seed, shape)
|
227 |
+
dx = (shape[2] - noise_shape[2]) // 2
|
228 |
+
dy = (shape[1] - noise_shape[1]) // 2
|
229 |
+
w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
|
230 |
+
h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
|
231 |
+
tx = 0 if dx < 0 else dx
|
232 |
+
ty = 0 if dy < 0 else dy
|
233 |
+
dx = max(-dx, 0)
|
234 |
+
dy = max(-dy, 0)
|
235 |
+
|
236 |
+
x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
|
237 |
+
noise = x
|
238 |
+
|
239 |
+
if sampler_noises is not None:
|
240 |
+
cnt = p.sampler.number_of_needed_noises(p)
|
241 |
+
|
242 |
+
for j in range(cnt):
|
243 |
+
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
|
244 |
+
|
245 |
+
xs.append(noise)
|
246 |
+
|
247 |
+
if sampler_noises is not None:
|
248 |
+
p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
|
249 |
+
|
250 |
+
x = torch.stack(xs).to(shared.device)
|
251 |
+
return x
|
252 |
+
|
253 |
+
|
254 |
+
def get_fixed_seed(seed):
|
255 |
+
if seed is None or seed == '' or seed == -1:
|
256 |
+
return int(random.randrange(4294967294))
|
257 |
+
|
258 |
+
return seed
|
259 |
+
|
260 |
+
|
261 |
+
def fix_seed(p):
|
262 |
+
p.seed = get_fixed_seed(p.seed)
|
263 |
+
p.subseed = get_fixed_seed(p.subseed)
|
264 |
+
|
265 |
+
|
266 |
+
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
|
267 |
+
index = position_in_batch + iteration * p.batch_size
|
268 |
+
|
269 |
+
generation_params = {
|
270 |
+
"Steps": p.steps,
|
271 |
+
"Sampler": sd_samplers.samplers[p.sampler_index].name,
|
272 |
+
"CFG scale": p.cfg_scale,
|
273 |
+
"Seed": all_seeds[index],
|
274 |
+
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
275 |
+
"Size": f"{p.width}x{p.height}",
|
276 |
+
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
277 |
+
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
278 |
+
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
279 |
+
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
280 |
+
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
281 |
+
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
282 |
+
"Denoising strength": getattr(p, 'denoising_strength', None),
|
283 |
+
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
284 |
+
}
|
285 |
+
|
286 |
+
generation_params.update(p.extra_generation_params)
|
287 |
+
|
288 |
+
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
|
289 |
+
|
290 |
+
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
|
291 |
+
|
292 |
+
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
293 |
+
|
294 |
+
|
295 |
+
def process_images(p: StableDiffusionProcessing) -> Processed:
|
296 |
+
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
297 |
+
|
298 |
+
if type(p.prompt) == list:
|
299 |
+
assert(len(p.prompt) > 0)
|
300 |
+
else:
|
301 |
+
assert p.prompt is not None
|
302 |
+
|
303 |
+
devices.torch_gc()
|
304 |
+
|
305 |
+
seed = get_fixed_seed(p.seed)
|
306 |
+
subseed = get_fixed_seed(p.subseed)
|
307 |
+
|
308 |
+
if p.outpath_samples is not None:
|
309 |
+
os.makedirs(p.outpath_samples, exist_ok=True)
|
310 |
+
|
311 |
+
if p.outpath_grids is not None:
|
312 |
+
os.makedirs(p.outpath_grids, exist_ok=True)
|
313 |
+
|
314 |
+
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
315 |
+
|
316 |
+
comments = {}
|
317 |
+
|
318 |
+
shared.prompt_styles.apply_styles(p)
|
319 |
+
|
320 |
+
if type(p.prompt) == list:
|
321 |
+
all_prompts = p.prompt
|
322 |
+
else:
|
323 |
+
all_prompts = p.batch_size * p.n_iter * [p.prompt]
|
324 |
+
|
325 |
+
if type(seed) == list:
|
326 |
+
all_seeds = seed
|
327 |
+
else:
|
328 |
+
all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
|
329 |
+
|
330 |
+
if type(subseed) == list:
|
331 |
+
all_subseeds = subseed
|
332 |
+
else:
|
333 |
+
all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
|
334 |
+
|
335 |
+
def infotext(iteration=0, position_in_batch=0):
|
336 |
+
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
|
337 |
+
|
338 |
+
if os.path.exists(cmd_opts.embeddings_dir):
|
339 |
+
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
340 |
+
|
341 |
+
infotexts = []
|
342 |
+
output_images = []
|
343 |
+
|
344 |
+
with torch.no_grad():
|
345 |
+
with devices.autocast():
|
346 |
+
p.init(all_prompts, all_seeds, all_subseeds)
|
347 |
+
|
348 |
+
if state.job_count == -1:
|
349 |
+
state.job_count = p.n_iter
|
350 |
+
|
351 |
+
for n in range(p.n_iter):
|
352 |
+
if state.interrupted:
|
353 |
+
break
|
354 |
+
|
355 |
+
prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
356 |
+
seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
357 |
+
subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
358 |
+
|
359 |
+
if (len(prompts) == 0):
|
360 |
+
break
|
361 |
+
|
362 |
+
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
|
363 |
+
#c = p.sd_model.get_learned_conditioning(prompts)
|
364 |
+
with devices.autocast():
|
365 |
+
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
|
366 |
+
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
|
367 |
+
|
368 |
+
if len(model_hijack.comments) > 0:
|
369 |
+
for comment in model_hijack.comments:
|
370 |
+
comments[comment] = 1
|
371 |
+
|
372 |
+
if p.n_iter > 1:
|
373 |
+
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
374 |
+
|
375 |
+
with devices.autocast():
|
376 |
+
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
377 |
+
|
378 |
+
if state.interrupted:
|
379 |
+
|
380 |
+
# if we are interruped, sample returns just noise
|
381 |
+
# use the image collected previously in sampler loop
|
382 |
+
samples_ddim = shared.state.current_latent
|
383 |
+
|
384 |
+
samples_ddim = samples_ddim.to(devices.dtype)
|
385 |
+
|
386 |
+
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
|
387 |
+
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
388 |
+
|
389 |
+
del samples_ddim
|
390 |
+
|
391 |
+
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
392 |
+
lowvram.send_everything_to_cpu()
|
393 |
+
|
394 |
+
devices.torch_gc()
|
395 |
+
|
396 |
+
if opts.filter_nsfw:
|
397 |
+
import modules.safety as safety
|
398 |
+
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
|
399 |
+
|
400 |
+
for i, x_sample in enumerate(x_samples_ddim):
|
401 |
+
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
402 |
+
x_sample = x_sample.astype(np.uint8)
|
403 |
+
|
404 |
+
if p.restore_faces:
|
405 |
+
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
|
406 |
+
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
|
407 |
+
|
408 |
+
devices.torch_gc()
|
409 |
+
|
410 |
+
x_sample = modules.face_restoration.restore_faces(x_sample)
|
411 |
+
devices.torch_gc()
|
412 |
+
|
413 |
+
image = Image.fromarray(x_sample)
|
414 |
+
|
415 |
+
if p.color_corrections is not None and i < len(p.color_corrections):
|
416 |
+
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
|
417 |
+
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
|
418 |
+
image = apply_color_correction(p.color_corrections[i], image)
|
419 |
+
|
420 |
+
if p.overlay_images is not None and i < len(p.overlay_images):
|
421 |
+
overlay = p.overlay_images[i]
|
422 |
+
|
423 |
+
if p.paste_to is not None:
|
424 |
+
x, y, w, h = p.paste_to
|
425 |
+
base_image = Image.new('RGBA', (overlay.width, overlay.height))
|
426 |
+
image = images.resize_image(1, image, w, h)
|
427 |
+
base_image.paste(image, (x, y))
|
428 |
+
image = base_image
|
429 |
+
|
430 |
+
image = image.convert('RGBA')
|
431 |
+
image.alpha_composite(overlay)
|
432 |
+
image = image.convert('RGB')
|
433 |
+
|
434 |
+
if opts.samples_save and not p.do_not_save_samples:
|
435 |
+
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
|
436 |
+
|
437 |
+
text = infotext(n, i)
|
438 |
+
infotexts.append(text)
|
439 |
+
image.info["parameters"] = text
|
440 |
+
output_images.append(image)
|
441 |
+
|
442 |
+
del x_samples_ddim
|
443 |
+
|
444 |
+
devices.torch_gc()
|
445 |
+
|
446 |
+
state.nextjob()
|
447 |
+
|
448 |
+
p.color_corrections = None
|
449 |
+
|
450 |
+
index_of_first_image = 0
|
451 |
+
unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
|
452 |
+
if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
|
453 |
+
grid = images.image_grid(output_images, p.batch_size)
|
454 |
+
|
455 |
+
if opts.return_grid:
|
456 |
+
text = infotext()
|
457 |
+
infotexts.insert(0, text)
|
458 |
+
grid.info["parameters"] = text
|
459 |
+
output_images.insert(0, grid)
|
460 |
+
index_of_first_image = 1
|
461 |
+
|
462 |
+
if opts.grid_save:
|
463 |
+
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
464 |
+
|
465 |
+
devices.torch_gc()
|
466 |
+
return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
467 |
+
|
468 |
+
|
469 |
+
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
470 |
+
sampler = None
|
471 |
+
firstphase_width = 0
|
472 |
+
firstphase_height = 0
|
473 |
+
firstphase_width_truncated = 0
|
474 |
+
firstphase_height_truncated = 0
|
475 |
+
|
476 |
+
def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs):
|
477 |
+
super().__init__(**kwargs)
|
478 |
+
self.enable_hr = enable_hr
|
479 |
+
self.scale_latent = scale_latent
|
480 |
+
self.denoising_strength = denoising_strength
|
481 |
+
|
482 |
+
def init(self, all_prompts, all_seeds, all_subseeds):
|
483 |
+
if self.enable_hr:
|
484 |
+
if state.job_count == -1:
|
485 |
+
state.job_count = self.n_iter * 2
|
486 |
+
else:
|
487 |
+
state.job_count = state.job_count * 2
|
488 |
+
|
489 |
+
desired_pixel_count = 512 * 512
|
490 |
+
actual_pixel_count = self.width * self.height
|
491 |
+
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
|
492 |
+
|
493 |
+
self.firstphase_width = math.ceil(scale * self.width / 64) * 64
|
494 |
+
self.firstphase_height = math.ceil(scale * self.height / 64) * 64
|
495 |
+
self.firstphase_width_truncated = int(scale * self.width)
|
496 |
+
self.firstphase_height_truncated = int(scale * self.height)
|
497 |
+
|
498 |
+
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
499 |
+
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
500 |
+
|
501 |
+
if not self.enable_hr:
|
502 |
+
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
503 |
+
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
|
504 |
+
return samples
|
505 |
+
|
506 |
+
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
507 |
+
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
|
508 |
+
|
509 |
+
truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
|
510 |
+
truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
|
511 |
+
|
512 |
+
samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]
|
513 |
+
|
514 |
+
if self.scale_latent:
|
515 |
+
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
516 |
+
else:
|
517 |
+
decoded_samples = self.sd_model.decode_first_stage(samples)
|
518 |
+
|
519 |
+
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
|
520 |
+
decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
|
521 |
+
else:
|
522 |
+
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
523 |
+
|
524 |
+
batch_images = []
|
525 |
+
for i, x_sample in enumerate(lowres_samples):
|
526 |
+
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
527 |
+
x_sample = x_sample.astype(np.uint8)
|
528 |
+
image = Image.fromarray(x_sample)
|
529 |
+
image = images.resize_image(0, image, self.width, self.height)
|
530 |
+
image = np.array(image).astype(np.float32) / 255.0
|
531 |
+
image = np.moveaxis(image, 2, 0)
|
532 |
+
batch_images.append(image)
|
533 |
+
|
534 |
+
decoded_samples = torch.from_numpy(np.array(batch_images))
|
535 |
+
decoded_samples = decoded_samples.to(shared.device)
|
536 |
+
decoded_samples = 2. * decoded_samples - 1.
|
537 |
+
|
538 |
+
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
|
539 |
+
|
540 |
+
shared.state.nextjob()
|
541 |
+
|
542 |
+
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
543 |
+
|
544 |
+
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
545 |
+
|
546 |
+
# GC now before running the next img2img to prevent running out of memory
|
547 |
+
x = None
|
548 |
+
devices.torch_gc()
|
549 |
+
|
550 |
+
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
|
551 |
+
|
552 |
+
return samples
|
553 |
+
|
554 |
+
|
555 |
+
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
556 |
+
sampler = None
|
557 |
+
|
558 |
+
def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
|
559 |
+
super().__init__(**kwargs)
|
560 |
+
|
561 |
+
self.init_images = init_images
|
562 |
+
self.resize_mode: int = resize_mode
|
563 |
+
self.denoising_strength: float = denoising_strength
|
564 |
+
self.init_latent = None
|
565 |
+
self.image_mask = mask
|
566 |
+
#self.image_unblurred_mask = None
|
567 |
+
self.latent_mask = None
|
568 |
+
self.mask_for_overlay = None
|
569 |
+
self.mask_blur = mask_blur
|
570 |
+
self.inpainting_fill = inpainting_fill
|
571 |
+
self.inpaint_full_res = inpaint_full_res
|
572 |
+
self.inpaint_full_res_padding = inpaint_full_res_padding
|
573 |
+
self.inpainting_mask_invert = inpainting_mask_invert
|
574 |
+
self.mask = None
|
575 |
+
self.nmask = None
|
576 |
+
|
577 |
+
def init(self, all_prompts, all_seeds, all_subseeds):
|
578 |
+
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
|
579 |
+
crop_region = None
|
580 |
+
|
581 |
+
if self.image_mask is not None:
|
582 |
+
self.image_mask = self.image_mask.convert('L')
|
583 |
+
|
584 |
+
if self.inpainting_mask_invert:
|
585 |
+
self.image_mask = ImageOps.invert(self.image_mask)
|
586 |
+
|
587 |
+
#self.image_unblurred_mask = self.image_mask
|
588 |
+
|
589 |
+
if self.mask_blur > 0:
|
590 |
+
self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
|
591 |
+
|
592 |
+
if self.inpaint_full_res:
|
593 |
+
self.mask_for_overlay = self.image_mask
|
594 |
+
mask = self.image_mask.convert('L')
|
595 |
+
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
|
596 |
+
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
|
597 |
+
x1, y1, x2, y2 = crop_region
|
598 |
+
|
599 |
+
mask = mask.crop(crop_region)
|
600 |
+
self.image_mask = images.resize_image(2, mask, self.width, self.height)
|
601 |
+
self.paste_to = (x1, y1, x2-x1, y2-y1)
|
602 |
+
else:
|
603 |
+
self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
|
604 |
+
np_mask = np.array(self.image_mask)
|
605 |
+
np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
|
606 |
+
self.mask_for_overlay = Image.fromarray(np_mask)
|
607 |
+
|
608 |
+
self.overlay_images = []
|
609 |
+
|
610 |
+
latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
|
611 |
+
|
612 |
+
add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
|
613 |
+
if add_color_corrections:
|
614 |
+
self.color_corrections = []
|
615 |
+
imgs = []
|
616 |
+
for img in self.init_images:
|
617 |
+
image = img.convert("RGB")
|
618 |
+
|
619 |
+
if crop_region is None:
|
620 |
+
image = images.resize_image(self.resize_mode, image, self.width, self.height)
|
621 |
+
|
622 |
+
if self.image_mask is not None:
|
623 |
+
image_masked = Image.new('RGBa', (image.width, image.height))
|
624 |
+
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
|
625 |
+
|
626 |
+
self.overlay_images.append(image_masked.convert('RGBA'))
|
627 |
+
|
628 |
+
if crop_region is not None:
|
629 |
+
image = image.crop(crop_region)
|
630 |
+
image = images.resize_image(2, image, self.width, self.height)
|
631 |
+
|
632 |
+
if self.image_mask is not None:
|
633 |
+
if self.inpainting_fill != 1:
|
634 |
+
image = masking.fill(image, latent_mask)
|
635 |
+
|
636 |
+
if add_color_corrections:
|
637 |
+
self.color_corrections.append(setup_color_correction(image))
|
638 |
+
|
639 |
+
image = np.array(image).astype(np.float32) / 255.0
|
640 |
+
image = np.moveaxis(image, 2, 0)
|
641 |
+
|
642 |
+
imgs.append(image)
|
643 |
+
|
644 |
+
if len(imgs) == 1:
|
645 |
+
batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
|
646 |
+
if self.overlay_images is not None:
|
647 |
+
self.overlay_images = self.overlay_images * self.batch_size
|
648 |
+
elif len(imgs) <= self.batch_size:
|
649 |
+
self.batch_size = len(imgs)
|
650 |
+
batch_images = np.array(imgs)
|
651 |
+
else:
|
652 |
+
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
|
653 |
+
|
654 |
+
image = torch.from_numpy(batch_images)
|
655 |
+
image = 2. * image - 1.
|
656 |
+
image = image.to(shared.device)
|
657 |
+
|
658 |
+
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|
659 |
+
|
660 |
+
if self.image_mask is not None:
|
661 |
+
init_mask = latent_mask
|
662 |
+
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
|
663 |
+
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
|
664 |
+
latmask = latmask[0]
|
665 |
+
latmask = np.around(latmask)
|
666 |
+
latmask = np.tile(latmask[None], (4, 1, 1))
|
667 |
+
|
668 |
+
self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
|
669 |
+
self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
|
670 |
+
|
671 |
+
# this needs to be fixed to be done in sample() using actual seeds for batches
|
672 |
+
if self.inpainting_fill == 2:
|
673 |
+
self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
|
674 |
+
elif self.inpainting_fill == 3:
|
675 |
+
self.init_latent = self.init_latent * self.mask
|
676 |
+
|
677 |
+
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
678 |
+
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
679 |
+
|
680 |
+
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
|
681 |
+
|
682 |
+
if self.mask is not None:
|
683 |
+
samples = samples * self.nmask + self.init_latent * self.mask
|
684 |
+
|
685 |
+
del x
|
686 |
+
devices.torch_gc()
|
687 |
+
|
688 |
+
return samples
|
modules/prompt_parser.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from collections import namedtuple
|
3 |
+
from typing import List
|
4 |
+
import lark
|
5 |
+
|
6 |
+
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
|
7 |
+
# will be represented with prompt_schedule like this (assuming steps=100):
|
8 |
+
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
|
9 |
+
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
|
10 |
+
# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
|
11 |
+
# [75, 'fantasy landscape with a lake and an oak in background masterful']
|
12 |
+
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
|
13 |
+
|
14 |
+
schedule_parser = lark.Lark(r"""
|
15 |
+
!start: (prompt | /[][():]/+)*
|
16 |
+
prompt: (emphasized | scheduled | plain | WHITESPACE)*
|
17 |
+
!emphasized: "(" prompt ")"
|
18 |
+
| "(" prompt ":" prompt ")"
|
19 |
+
| "[" prompt "]"
|
20 |
+
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
|
21 |
+
WHITESPACE: /\s+/
|
22 |
+
plain: /([^\\\[\]():]|\\.)+/
|
23 |
+
%import common.SIGNED_NUMBER -> NUMBER
|
24 |
+
""")
|
25 |
+
|
26 |
+
def get_learned_conditioning_prompt_schedules(prompts, steps):
|
27 |
+
"""
|
28 |
+
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
|
29 |
+
>>> g("test")
|
30 |
+
[[10, 'test']]
|
31 |
+
>>> g("a [b:3]")
|
32 |
+
[[3, 'a '], [10, 'a b']]
|
33 |
+
>>> g("a [b: 3]")
|
34 |
+
[[3, 'a '], [10, 'a b']]
|
35 |
+
>>> g("a [[[b]]:2]")
|
36 |
+
[[2, 'a '], [10, 'a [[b]]']]
|
37 |
+
>>> g("[(a:2):3]")
|
38 |
+
[[3, ''], [10, '(a:2)']]
|
39 |
+
>>> g("a [b : c : 1] d")
|
40 |
+
[[1, 'a b d'], [10, 'a c d']]
|
41 |
+
>>> g("a[b:[c:d:2]:1]e")
|
42 |
+
[[1, 'abe'], [2, 'ace'], [10, 'ade']]
|
43 |
+
>>> g("a [unbalanced")
|
44 |
+
[[10, 'a [unbalanced']]
|
45 |
+
>>> g("a [b:.5] c")
|
46 |
+
[[5, 'a c'], [10, 'a b c']]
|
47 |
+
>>> g("a [{b|d{:.5] c") # not handling this right now
|
48 |
+
[[5, 'a c'], [10, 'a {b|d{ c']]
|
49 |
+
>>> g("((a][:b:c [d:3]")
|
50 |
+
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
|
51 |
+
"""
|
52 |
+
|
53 |
+
def collect_steps(steps, tree):
|
54 |
+
l = [steps]
|
55 |
+
class CollectSteps(lark.Visitor):
|
56 |
+
def scheduled(self, tree):
|
57 |
+
tree.children[-1] = float(tree.children[-1])
|
58 |
+
if tree.children[-1] < 1:
|
59 |
+
tree.children[-1] *= steps
|
60 |
+
tree.children[-1] = min(steps, int(tree.children[-1]))
|
61 |
+
l.append(tree.children[-1])
|
62 |
+
CollectSteps().visit(tree)
|
63 |
+
return sorted(set(l))
|
64 |
+
|
65 |
+
def at_step(step, tree):
|
66 |
+
class AtStep(lark.Transformer):
|
67 |
+
def scheduled(self, args):
|
68 |
+
before, after, _, when = args
|
69 |
+
yield before or () if step <= when else after
|
70 |
+
def start(self, args):
|
71 |
+
def flatten(x):
|
72 |
+
if type(x) == str:
|
73 |
+
yield x
|
74 |
+
else:
|
75 |
+
for gen in x:
|
76 |
+
yield from flatten(gen)
|
77 |
+
return ''.join(flatten(args))
|
78 |
+
def plain(self, args):
|
79 |
+
yield args[0].value
|
80 |
+
def __default__(self, data, children, meta):
|
81 |
+
for child in children:
|
82 |
+
yield from child
|
83 |
+
return AtStep().transform(tree)
|
84 |
+
|
85 |
+
def get_schedule(prompt):
|
86 |
+
try:
|
87 |
+
tree = schedule_parser.parse(prompt)
|
88 |
+
except lark.exceptions.LarkError as e:
|
89 |
+
if 0:
|
90 |
+
import traceback
|
91 |
+
traceback.print_exc()
|
92 |
+
return [[steps, prompt]]
|
93 |
+
return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
|
94 |
+
|
95 |
+
promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
|
96 |
+
return [promptdict[prompt] for prompt in prompts]
|
97 |
+
|
98 |
+
|
99 |
+
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
|
100 |
+
|
101 |
+
|
102 |
+
def get_learned_conditioning(model, prompts, steps):
|
103 |
+
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
|
104 |
+
and the sampling step at which this condition is to be replaced by the next one.
|
105 |
+
|
106 |
+
Input:
|
107 |
+
(model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
|
108 |
+
|
109 |
+
Output:
|
110 |
+
[
|
111 |
+
[
|
112 |
+
ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
|
113 |
+
],
|
114 |
+
[
|
115 |
+
ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
|
116 |
+
ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
|
117 |
+
]
|
118 |
+
]
|
119 |
+
"""
|
120 |
+
res = []
|
121 |
+
|
122 |
+
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
|
123 |
+
cache = {}
|
124 |
+
|
125 |
+
for prompt, prompt_schedule in zip(prompts, prompt_schedules):
|
126 |
+
|
127 |
+
cached = cache.get(prompt, None)
|
128 |
+
if cached is not None:
|
129 |
+
res.append(cached)
|
130 |
+
continue
|
131 |
+
|
132 |
+
texts = [x[1] for x in prompt_schedule]
|
133 |
+
conds = model.get_learned_conditioning(texts)
|
134 |
+
|
135 |
+
cond_schedule = []
|
136 |
+
for i, (end_at_step, text) in enumerate(prompt_schedule):
|
137 |
+
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
|
138 |
+
|
139 |
+
cache[prompt] = cond_schedule
|
140 |
+
res.append(cond_schedule)
|
141 |
+
|
142 |
+
return res
|
143 |
+
|
144 |
+
|
145 |
+
re_AND = re.compile(r"\bAND\b")
|
146 |
+
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
|
147 |
+
|
148 |
+
def get_multicond_prompt_list(prompts):
|
149 |
+
res_indexes = []
|
150 |
+
|
151 |
+
prompt_flat_list = []
|
152 |
+
prompt_indexes = {}
|
153 |
+
|
154 |
+
for prompt in prompts:
|
155 |
+
subprompts = re_AND.split(prompt)
|
156 |
+
|
157 |
+
indexes = []
|
158 |
+
for subprompt in subprompts:
|
159 |
+
match = re_weight.search(subprompt)
|
160 |
+
|
161 |
+
text, weight = match.groups() if match is not None else (subprompt, 1.0)
|
162 |
+
|
163 |
+
weight = float(weight) if weight is not None else 1.0
|
164 |
+
|
165 |
+
index = prompt_indexes.get(text, None)
|
166 |
+
if index is None:
|
167 |
+
index = len(prompt_flat_list)
|
168 |
+
prompt_flat_list.append(text)
|
169 |
+
prompt_indexes[text] = index
|
170 |
+
|
171 |
+
indexes.append((index, weight))
|
172 |
+
|
173 |
+
res_indexes.append(indexes)
|
174 |
+
|
175 |
+
return res_indexes, prompt_flat_list, prompt_indexes
|
176 |
+
|
177 |
+
|
178 |
+
class ComposableScheduledPromptConditioning:
|
179 |
+
def __init__(self, schedules, weight=1.0):
|
180 |
+
self.schedules: List[ScheduledPromptConditioning] = schedules
|
181 |
+
self.weight: float = weight
|
182 |
+
|
183 |
+
|
184 |
+
class MulticondLearnedConditioning:
|
185 |
+
def __init__(self, shape, batch):
|
186 |
+
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
|
187 |
+
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
|
188 |
+
|
189 |
+
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
|
190 |
+
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
|
191 |
+
For each prompt, the list is obtained by splitting the prompt using the AND separator.
|
192 |
+
|
193 |
+
https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
|
194 |
+
"""
|
195 |
+
|
196 |
+
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
|
197 |
+
|
198 |
+
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
|
199 |
+
|
200 |
+
res = []
|
201 |
+
for indexes in res_indexes:
|
202 |
+
res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
|
203 |
+
|
204 |
+
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
|
205 |
+
|
206 |
+
|
207 |
+
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
|
208 |
+
param = c[0][0].cond
|
209 |
+
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
210 |
+
for i, cond_schedule in enumerate(c):
|
211 |
+
target_index = 0
|
212 |
+
for current, (end_at, cond) in enumerate(cond_schedule):
|
213 |
+
if current_step <= end_at:
|
214 |
+
target_index = current
|
215 |
+
break
|
216 |
+
res[i] = cond_schedule[target_index].cond
|
217 |
+
|
218 |
+
return res
|
219 |
+
|
220 |
+
|
221 |
+
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
222 |
+
param = c.batch[0][0].schedules[0].cond
|
223 |
+
|
224 |
+
tensors = []
|
225 |
+
conds_list = []
|
226 |
+
|
227 |
+
for batch_no, composable_prompts in enumerate(c.batch):
|
228 |
+
conds_for_batch = []
|
229 |
+
|
230 |
+
for cond_index, composable_prompt in enumerate(composable_prompts):
|
231 |
+
target_index = 0
|
232 |
+
for current, (end_at, cond) in enumerate(composable_prompt.schedules):
|
233 |
+
if current_step <= end_at:
|
234 |
+
target_index = current
|
235 |
+
break
|
236 |
+
|
237 |
+
conds_for_batch.append((len(tensors), composable_prompt.weight))
|
238 |
+
tensors.append(composable_prompt.schedules[target_index].cond)
|
239 |
+
|
240 |
+
conds_list.append(conds_for_batch)
|
241 |
+
|
242 |
+
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
|
243 |
+
|
244 |
+
|
245 |
+
re_attention = re.compile(r"""
|
246 |
+
\\\(|
|
247 |
+
\\\)|
|
248 |
+
\\\[|
|
249 |
+
\\]|
|
250 |
+
\\\\|
|
251 |
+
\\|
|
252 |
+
\(|
|
253 |
+
\[|
|
254 |
+
:([+-]?[.\d]+)\)|
|
255 |
+
\)|
|
256 |
+
]|
|
257 |
+
[^\\()\[\]:]+|
|
258 |
+
:
|
259 |
+
""", re.X)
|
260 |
+
|
261 |
+
|
262 |
+
def parse_prompt_attention(text):
|
263 |
+
"""
|
264 |
+
Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight.
|
265 |
+
Accepted tokens are:
|
266 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
267 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
268 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
269 |
+
\( - literal character '('
|
270 |
+
\[ - literal character '['
|
271 |
+
\) - literal character ')'
|
272 |
+
\] - literal character ']'
|
273 |
+
\\ - literal character '\'
|
274 |
+
anything else - just text
|
275 |
+
|
276 |
+
>>> parse_prompt_attention('normal text')
|
277 |
+
[['normal text', 1.0]]
|
278 |
+
>>> parse_prompt_attention('an (important) word')
|
279 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
280 |
+
>>> parse_prompt_attention('(unbalanced')
|
281 |
+
[['unbalanced', 1.1]]
|
282 |
+
>>> parse_prompt_attention('\(literal\]')
|
283 |
+
[['(literal]', 1.0]]
|
284 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
285 |
+
[['unnecessaryparens', 1.1]]
|
286 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
287 |
+
[['a ', 1.0],
|
288 |
+
['house', 1.5730000000000004],
|
289 |
+
[' ', 1.1],
|
290 |
+
['on', 1.0],
|
291 |
+
[' a ', 1.1],
|
292 |
+
['hill', 0.55],
|
293 |
+
[', sun, ', 1.1],
|
294 |
+
['sky', 1.4641000000000006],
|
295 |
+
['.', 1.1]]
|
296 |
+
"""
|
297 |
+
|
298 |
+
res = []
|
299 |
+
round_brackets = []
|
300 |
+
square_brackets = []
|
301 |
+
|
302 |
+
round_bracket_multiplier = 1.1
|
303 |
+
square_bracket_multiplier = 1 / 1.1
|
304 |
+
|
305 |
+
def multiply_range(start_position, multiplier):
|
306 |
+
for p in range(start_position, len(res)):
|
307 |
+
res[p][1] *= multiplier
|
308 |
+
|
309 |
+
for m in re_attention.finditer(text):
|
310 |
+
text = m.group(0)
|
311 |
+
weight = m.group(1)
|
312 |
+
|
313 |
+
if text.startswith('\\'):
|
314 |
+
res.append([text[1:], 1.0])
|
315 |
+
elif text == '(':
|
316 |
+
round_brackets.append(len(res))
|
317 |
+
elif text == '[':
|
318 |
+
square_brackets.append(len(res))
|
319 |
+
elif weight is not None and len(round_brackets) > 0:
|
320 |
+
multiply_range(round_brackets.pop(), float(weight))
|
321 |
+
elif text == ')' and len(round_brackets) > 0:
|
322 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
323 |
+
elif text == ']' and len(square_brackets) > 0:
|
324 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
325 |
+
else:
|
326 |
+
res.append([text, 1.0])
|
327 |
+
|
328 |
+
for pos in round_brackets:
|
329 |
+
multiply_range(pos, round_bracket_multiplier)
|
330 |
+
|
331 |
+
for pos in square_brackets:
|
332 |
+
multiply_range(pos, square_bracket_multiplier)
|
333 |
+
|
334 |
+
if len(res) == 0:
|
335 |
+
res = [["", 1.0]]
|
336 |
+
|
337 |
+
# merge runs of identical weights
|
338 |
+
i = 0
|
339 |
+
while i + 1 < len(res):
|
340 |
+
if res[i][1] == res[i + 1][1]:
|
341 |
+
res[i][0] += res[i + 1][0]
|
342 |
+
res.pop(i + 1)
|
343 |
+
else:
|
344 |
+
i += 1
|
345 |
+
|
346 |
+
return res
|
347 |
+
|
348 |
+
if __name__ == "__main__":
|
349 |
+
import doctest
|
350 |
+
doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
|
351 |
+
else:
|
352 |
+
import torch # doctest faster
|
modules/realesrgan_model.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
from basicsr.utils.download_util import load_file_from_url
|
8 |
+
from realesrgan import RealESRGANer
|
9 |
+
|
10 |
+
from modules.upscaler import Upscaler, UpscalerData
|
11 |
+
from modules.paths import models_path
|
12 |
+
from modules.shared import cmd_opts, opts
|
13 |
+
|
14 |
+
|
15 |
+
class UpscalerRealESRGAN(Upscaler):
|
16 |
+
def __init__(self, path):
|
17 |
+
self.name = "RealESRGAN"
|
18 |
+
self.model_path = os.path.join(models_path, self.name)
|
19 |
+
self.user_path = path
|
20 |
+
super().__init__()
|
21 |
+
try:
|
22 |
+
from basicsr.archs.rrdbnet_arch import RRDBNet
|
23 |
+
from realesrgan import RealESRGANer
|
24 |
+
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
25 |
+
self.enable = True
|
26 |
+
self.scalers = []
|
27 |
+
scalers = self.load_models(path)
|
28 |
+
for scaler in scalers:
|
29 |
+
if scaler.name in opts.realesrgan_enabled_models:
|
30 |
+
self.scalers.append(scaler)
|
31 |
+
|
32 |
+
except Exception:
|
33 |
+
print("Error importing Real-ESRGAN:", file=sys.stderr)
|
34 |
+
print(traceback.format_exc(), file=sys.stderr)
|
35 |
+
self.enable = False
|
36 |
+
self.scalers = []
|
37 |
+
|
38 |
+
def do_upscale(self, img, path):
|
39 |
+
if not self.enable:
|
40 |
+
return img
|
41 |
+
|
42 |
+
info = self.load_model(path)
|
43 |
+
if not os.path.exists(info.data_path):
|
44 |
+
print("Unable to load RealESRGAN model: %s" % info.name)
|
45 |
+
return img
|
46 |
+
|
47 |
+
upsampler = RealESRGANer(
|
48 |
+
scale=info.scale,
|
49 |
+
model_path=info.data_path,
|
50 |
+
model=info.model(),
|
51 |
+
half=not cmd_opts.no_half,
|
52 |
+
tile=opts.ESRGAN_tile,
|
53 |
+
tile_pad=opts.ESRGAN_tile_overlap,
|
54 |
+
)
|
55 |
+
|
56 |
+
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
|
57 |
+
|
58 |
+
image = Image.fromarray(upsampled)
|
59 |
+
return image
|
60 |
+
|
61 |
+
def load_model(self, path):
|
62 |
+
try:
|
63 |
+
info = None
|
64 |
+
for scaler in self.scalers:
|
65 |
+
if scaler.data_path == path:
|
66 |
+
info = scaler
|
67 |
+
|
68 |
+
if info is None:
|
69 |
+
print(f"Unable to find model info: {path}")
|
70 |
+
return None
|
71 |
+
|
72 |
+
model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
|
73 |
+
info.data_path = model_file
|
74 |
+
return info
|
75 |
+
except Exception as e:
|
76 |
+
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
|
77 |
+
print(traceback.format_exc(), file=sys.stderr)
|
78 |
+
return None
|
79 |
+
|
80 |
+
def load_models(self, _):
|
81 |
+
return get_realesrgan_models(self)
|
82 |
+
|
83 |
+
|
84 |
+
def get_realesrgan_models(scaler):
|
85 |
+
try:
|
86 |
+
from basicsr.archs.rrdbnet_arch import RRDBNet
|
87 |
+
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
88 |
+
models = [
|
89 |
+
UpscalerData(
|
90 |
+
name="R-ESRGAN General 4xV3",
|
91 |
+
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
92 |
+
scale=4,
|
93 |
+
upscaler=scaler,
|
94 |
+
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
95 |
+
),
|
96 |
+
UpscalerData(
|
97 |
+
name="R-ESRGAN General WDN 4xV3",
|
98 |
+
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
|
99 |
+
scale=4,
|
100 |
+
upscaler=scaler,
|
101 |
+
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
102 |
+
),
|
103 |
+
UpscalerData(
|
104 |
+
name="R-ESRGAN AnimeVideo",
|
105 |
+
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
|
106 |
+
scale=4,
|
107 |
+
upscaler=scaler,
|
108 |
+
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
|
109 |
+
),
|
110 |
+
UpscalerData(
|
111 |
+
name="R-ESRGAN 4x+",
|
112 |
+
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
113 |
+
scale=4,
|
114 |
+
upscaler=scaler,
|
115 |
+
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
116 |
+
),
|
117 |
+
UpscalerData(
|
118 |
+
name="R-ESRGAN 4x+ Anime6B",
|
119 |
+
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
120 |
+
scale=4,
|
121 |
+
upscaler=scaler,
|
122 |
+
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
123 |
+
),
|
124 |
+
UpscalerData(
|
125 |
+
name="R-ESRGAN 2x+",
|
126 |
+
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
127 |
+
scale=2,
|
128 |
+
upscaler=scaler,
|
129 |
+
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
|
130 |
+
),
|
131 |
+
]
|
132 |
+
return models
|
133 |
+
except Exception as e:
|
134 |
+
print("Error making Real-ESRGAN models list:", file=sys.stderr)
|
135 |
+
print(traceback.format_exc(), file=sys.stderr)
|
modules/safety.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
3 |
+
from transformers import AutoFeatureExtractor
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
import modules.shared as shared
|
7 |
+
|
8 |
+
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
9 |
+
safety_feature_extractor = None
|
10 |
+
safety_checker = None
|
11 |
+
|
12 |
+
def numpy_to_pil(images):
|
13 |
+
"""
|
14 |
+
Convert a numpy image or a batch of images to a PIL image.
|
15 |
+
"""
|
16 |
+
if images.ndim == 3:
|
17 |
+
images = images[None, ...]
|
18 |
+
images = (images * 255).round().astype("uint8")
|
19 |
+
pil_images = [Image.fromarray(image) for image in images]
|
20 |
+
|
21 |
+
return pil_images
|
22 |
+
|
23 |
+
# check and replace nsfw content
|
24 |
+
def check_safety(x_image):
|
25 |
+
global safety_feature_extractor, safety_checker
|
26 |
+
|
27 |
+
if safety_feature_extractor is None:
|
28 |
+
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
29 |
+
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
30 |
+
|
31 |
+
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
|
32 |
+
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
33 |
+
|
34 |
+
return x_checked_image, has_nsfw_concept
|
35 |
+
|
36 |
+
|
37 |
+
def censor_batch(x):
|
38 |
+
x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
|
39 |
+
x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
|
40 |
+
x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
|
41 |
+
|
42 |
+
return x
|
modules/scripts.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
import modules.ui as ui
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
from modules.processing import StableDiffusionProcessing
|
9 |
+
from modules import shared
|
10 |
+
|
11 |
+
class Script:
|
12 |
+
filename = None
|
13 |
+
args_from = None
|
14 |
+
args_to = None
|
15 |
+
|
16 |
+
# The title of the script. This is what will be displayed in the dropdown menu.
|
17 |
+
def title(self):
|
18 |
+
raise NotImplementedError()
|
19 |
+
|
20 |
+
# How the script is displayed in the UI. See https://gradio.app/docs/#components
|
21 |
+
# for the different UI components you can use and how to create them.
|
22 |
+
# Most UI components can return a value, such as a boolean for a checkbox.
|
23 |
+
# The returned values are passed to the run method as parameters.
|
24 |
+
def ui(self, is_img2img):
|
25 |
+
pass
|
26 |
+
|
27 |
+
# Determines when the script should be shown in the dropdown menu via the
|
28 |
+
# returned value. As an example:
|
29 |
+
# is_img2img is True if the current tab is img2img, and False if it is txt2img.
|
30 |
+
# Thus, return is_img2img to only show the script on the img2img tab.
|
31 |
+
def show(self, is_img2img):
|
32 |
+
return True
|
33 |
+
|
34 |
+
# This is where the additional processing is implemented. The parameters include
|
35 |
+
# self, the model object "p" (a StableDiffusionProcessing class, see
|
36 |
+
# processing.py), and the parameters returned by the ui method.
|
37 |
+
# Custom functions can be defined here, and additional libraries can be imported
|
38 |
+
# to be used in processing. The return value should be a Processed object, which is
|
39 |
+
# what is returned by the process_images method.
|
40 |
+
def run(self, *args):
|
41 |
+
raise NotImplementedError()
|
42 |
+
|
43 |
+
# The description method is currently unused.
|
44 |
+
# To add a description that appears when hovering over the title, amend the "titles"
|
45 |
+
# dict in script.js to include the script title (returned by title) as a key, and
|
46 |
+
# your description as the value.
|
47 |
+
def describe(self):
|
48 |
+
return ""
|
49 |
+
|
50 |
+
|
51 |
+
scripts_data = []
|
52 |
+
|
53 |
+
|
54 |
+
def load_scripts(basedir):
|
55 |
+
if not os.path.exists(basedir):
|
56 |
+
return
|
57 |
+
|
58 |
+
for filename in sorted(os.listdir(basedir)):
|
59 |
+
path = os.path.join(basedir, filename)
|
60 |
+
|
61 |
+
if not os.path.isfile(path):
|
62 |
+
continue
|
63 |
+
|
64 |
+
try:
|
65 |
+
with open(path, "r", encoding="utf8") as file:
|
66 |
+
text = file.read()
|
67 |
+
|
68 |
+
from types import ModuleType
|
69 |
+
compiled = compile(text, path, 'exec')
|
70 |
+
module = ModuleType(filename)
|
71 |
+
exec(compiled, module.__dict__)
|
72 |
+
|
73 |
+
for key, script_class in module.__dict__.items():
|
74 |
+
if type(script_class) == type and issubclass(script_class, Script):
|
75 |
+
scripts_data.append((script_class, path))
|
76 |
+
|
77 |
+
except Exception:
|
78 |
+
print(f"Error loading script: {filename}", file=sys.stderr)
|
79 |
+
print(traceback.format_exc(), file=sys.stderr)
|
80 |
+
|
81 |
+
|
82 |
+
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
83 |
+
try:
|
84 |
+
res = func(*args, **kwargs)
|
85 |
+
return res
|
86 |
+
except Exception:
|
87 |
+
print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
|
88 |
+
print(traceback.format_exc(), file=sys.stderr)
|
89 |
+
|
90 |
+
return default
|
91 |
+
|
92 |
+
|
93 |
+
class ScriptRunner:
|
94 |
+
def __init__(self):
|
95 |
+
self.scripts = []
|
96 |
+
|
97 |
+
def setup_ui(self, is_img2img):
|
98 |
+
for script_class, path in scripts_data:
|
99 |
+
script = script_class()
|
100 |
+
script.filename = path
|
101 |
+
|
102 |
+
if not script.show(is_img2img):
|
103 |
+
continue
|
104 |
+
|
105 |
+
self.scripts.append(script)
|
106 |
+
|
107 |
+
titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
|
108 |
+
|
109 |
+
dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index")
|
110 |
+
inputs = [dropdown]
|
111 |
+
|
112 |
+
for script in self.scripts:
|
113 |
+
script.args_from = len(inputs)
|
114 |
+
script.args_to = len(inputs)
|
115 |
+
|
116 |
+
controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
|
117 |
+
|
118 |
+
if controls is None:
|
119 |
+
continue
|
120 |
+
|
121 |
+
for control in controls:
|
122 |
+
control.custom_script_source = os.path.basename(script.filename)
|
123 |
+
control.visible = False
|
124 |
+
|
125 |
+
inputs += controls
|
126 |
+
script.args_to = len(inputs)
|
127 |
+
|
128 |
+
def select_script(script_index):
|
129 |
+
if 0 < script_index <= len(self.scripts):
|
130 |
+
script = self.scripts[script_index-1]
|
131 |
+
args_from = script.args_from
|
132 |
+
args_to = script.args_to
|
133 |
+
else:
|
134 |
+
args_from = 0
|
135 |
+
args_to = 0
|
136 |
+
|
137 |
+
return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
|
138 |
+
|
139 |
+
dropdown.change(
|
140 |
+
fn=select_script,
|
141 |
+
inputs=[dropdown],
|
142 |
+
outputs=inputs
|
143 |
+
)
|
144 |
+
|
145 |
+
return inputs
|
146 |
+
|
147 |
+
def run(self, p: StableDiffusionProcessing, *args):
|
148 |
+
script_index = args[0]
|
149 |
+
|
150 |
+
if script_index == 0:
|
151 |
+
return None
|
152 |
+
|
153 |
+
script = self.scripts[script_index-1]
|
154 |
+
|
155 |
+
if script is None:
|
156 |
+
return None
|
157 |
+
|
158 |
+
script_args = args[script.args_from:script.args_to]
|
159 |
+
processed = script.run(p, *script_args)
|
160 |
+
|
161 |
+
shared.total_tqdm.clear()
|
162 |
+
|
163 |
+
return processed
|
164 |
+
|
165 |
+
def reload_sources(self):
|
166 |
+
for si, script in list(enumerate(self.scripts)):
|
167 |
+
with open(script.filename, "r", encoding="utf8") as file:
|
168 |
+
args_from = script.args_from
|
169 |
+
args_to = script.args_to
|
170 |
+
filename = script.filename
|
171 |
+
text = file.read()
|
172 |
+
|
173 |
+
from types import ModuleType
|
174 |
+
|
175 |
+
compiled = compile(text, filename, 'exec')
|
176 |
+
module = ModuleType(script.filename)
|
177 |
+
exec(compiled, module.__dict__)
|
178 |
+
|
179 |
+
for key, script_class in module.__dict__.items():
|
180 |
+
if type(script_class) == type and issubclass(script_class, Script):
|
181 |
+
self.scripts[si] = script_class()
|
182 |
+
self.scripts[si].filename = filename
|
183 |
+
self.scripts[si].args_from = args_from
|
184 |
+
self.scripts[si].args_to = args_to
|
185 |
+
|
186 |
+
scripts_txt2img = ScriptRunner()
|
187 |
+
scripts_img2img = ScriptRunner()
|
188 |
+
|
189 |
+
def reload_script_body_only():
|
190 |
+
scripts_txt2img.reload_sources()
|
191 |
+
scripts_img2img.reload_sources()
|
192 |
+
|
193 |
+
|
194 |
+
def reload_scripts(basedir):
|
195 |
+
global scripts_txt2img, scripts_img2img
|
196 |
+
|
197 |
+
scripts_data.clear()
|
198 |
+
load_scripts(basedir)
|
199 |
+
|
200 |
+
scripts_txt2img = ScriptRunner()
|
201 |
+
scripts_img2img = ScriptRunner()
|
modules/scunet_model.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
import sys
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
import PIL.Image
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from basicsr.utils.download_util import load_file_from_url
|
9 |
+
|
10 |
+
import modules.upscaler
|
11 |
+
from modules import devices, modelloader
|
12 |
+
from modules.paths import models_path
|
13 |
+
from modules.scunet_model_arch import SCUNet as net
|
14 |
+
|
15 |
+
|
16 |
+
class UpscalerScuNET(modules.upscaler.Upscaler):
|
17 |
+
def __init__(self, dirname):
|
18 |
+
self.name = "ScuNET"
|
19 |
+
self.model_path = os.path.join(models_path, self.name)
|
20 |
+
self.model_name = "ScuNET GAN"
|
21 |
+
self.model_name2 = "ScuNET PSNR"
|
22 |
+
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
|
23 |
+
self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
|
24 |
+
self.user_path = dirname
|
25 |
+
super().__init__()
|
26 |
+
model_paths = self.find_models(ext_filter=[".pth"])
|
27 |
+
scalers = []
|
28 |
+
add_model2 = True
|
29 |
+
for file in model_paths:
|
30 |
+
if "http" in file:
|
31 |
+
name = self.model_name
|
32 |
+
else:
|
33 |
+
name = modelloader.friendly_name(file)
|
34 |
+
if name == self.model_name2 or file == self.model_url2:
|
35 |
+
add_model2 = False
|
36 |
+
try:
|
37 |
+
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
38 |
+
scalers.append(scaler_data)
|
39 |
+
except Exception:
|
40 |
+
print(f"Error loading ScuNET model: {file}", file=sys.stderr)
|
41 |
+
print(traceback.format_exc(), file=sys.stderr)
|
42 |
+
if add_model2:
|
43 |
+
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
|
44 |
+
scalers.append(scaler_data2)
|
45 |
+
self.scalers = scalers
|
46 |
+
|
47 |
+
def do_upscale(self, img: PIL.Image, selected_file):
|
48 |
+
torch.cuda.empty_cache()
|
49 |
+
|
50 |
+
model = self.load_model(selected_file)
|
51 |
+
if model is None:
|
52 |
+
return img
|
53 |
+
|
54 |
+
device = devices.device_scunet
|
55 |
+
img = np.array(img)
|
56 |
+
img = img[:, :, ::-1]
|
57 |
+
img = np.moveaxis(img, 2, 0) / 255
|
58 |
+
img = torch.from_numpy(img).float()
|
59 |
+
img = img.unsqueeze(0).to(device)
|
60 |
+
|
61 |
+
img = img.to(device)
|
62 |
+
with torch.no_grad():
|
63 |
+
output = model(img)
|
64 |
+
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
65 |
+
output = 255. * np.moveaxis(output, 0, 2)
|
66 |
+
output = output.astype(np.uint8)
|
67 |
+
output = output[:, :, ::-1]
|
68 |
+
torch.cuda.empty_cache()
|
69 |
+
return PIL.Image.fromarray(output, 'RGB')
|
70 |
+
|
71 |
+
def load_model(self, path: str):
|
72 |
+
device = devices.device_scunet
|
73 |
+
if "http" in path:
|
74 |
+
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
75 |
+
progress=True)
|
76 |
+
else:
|
77 |
+
filename = path
|
78 |
+
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
|
79 |
+
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
|
80 |
+
return None
|
81 |
+
|
82 |
+
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
83 |
+
model.load_state_dict(torch.load(filename), strict=True)
|
84 |
+
model.eval()
|
85 |
+
for k, v in model.named_parameters():
|
86 |
+
v.requires_grad = False
|
87 |
+
model = model.to(device)
|
88 |
+
|
89 |
+
return model
|
90 |
+
|
modules/scunet_model_arch.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from einops import rearrange
|
6 |
+
from einops.layers.torch import Rearrange
|
7 |
+
from timm.models.layers import trunc_normal_, DropPath
|
8 |
+
|
9 |
+
|
10 |
+
class WMSA(nn.Module):
|
11 |
+
""" Self-attention module in Swin Transformer
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
|
15 |
+
super(WMSA, self).__init__()
|
16 |
+
self.input_dim = input_dim
|
17 |
+
self.output_dim = output_dim
|
18 |
+
self.head_dim = head_dim
|
19 |
+
self.scale = self.head_dim ** -0.5
|
20 |
+
self.n_heads = input_dim // head_dim
|
21 |
+
self.window_size = window_size
|
22 |
+
self.type = type
|
23 |
+
self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
|
24 |
+
|
25 |
+
self.relative_position_params = nn.Parameter(
|
26 |
+
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
|
27 |
+
|
28 |
+
self.linear = nn.Linear(self.input_dim, self.output_dim)
|
29 |
+
|
30 |
+
trunc_normal_(self.relative_position_params, std=.02)
|
31 |
+
self.relative_position_params = torch.nn.Parameter(
|
32 |
+
self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
|
33 |
+
2).transpose(
|
34 |
+
0, 1))
|
35 |
+
|
36 |
+
def generate_mask(self, h, w, p, shift):
|
37 |
+
""" generating the mask of SW-MSA
|
38 |
+
Args:
|
39 |
+
shift: shift parameters in CyclicShift.
|
40 |
+
Returns:
|
41 |
+
attn_mask: should be (1 1 w p p),
|
42 |
+
"""
|
43 |
+
# supporting sqaure.
|
44 |
+
attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
|
45 |
+
if self.type == 'W':
|
46 |
+
return attn_mask
|
47 |
+
|
48 |
+
s = p - shift
|
49 |
+
attn_mask[-1, :, :s, :, s:, :] = True
|
50 |
+
attn_mask[-1, :, s:, :, :s, :] = True
|
51 |
+
attn_mask[:, -1, :, :s, :, s:] = True
|
52 |
+
attn_mask[:, -1, :, s:, :, :s] = True
|
53 |
+
attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
|
54 |
+
return attn_mask
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
""" Forward pass of Window Multi-head Self-attention module.
|
58 |
+
Args:
|
59 |
+
x: input tensor with shape of [b h w c];
|
60 |
+
attn_mask: attention mask, fill -inf where the value is True;
|
61 |
+
Returns:
|
62 |
+
output: tensor shape [b h w c]
|
63 |
+
"""
|
64 |
+
if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
|
65 |
+
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
66 |
+
h_windows = x.size(1)
|
67 |
+
w_windows = x.size(2)
|
68 |
+
# sqaure validation
|
69 |
+
# assert h_windows == w_windows
|
70 |
+
|
71 |
+
x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
|
72 |
+
qkv = self.embedding_layer(x)
|
73 |
+
q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
|
74 |
+
sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
|
75 |
+
# Adding learnable relative embedding
|
76 |
+
sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
|
77 |
+
# Using Attn Mask to distinguish different subwindows.
|
78 |
+
if self.type != 'W':
|
79 |
+
attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
|
80 |
+
sim = sim.masked_fill_(attn_mask, float("-inf"))
|
81 |
+
|
82 |
+
probs = nn.functional.softmax(sim, dim=-1)
|
83 |
+
output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
|
84 |
+
output = rearrange(output, 'h b w p c -> b w p (h c)')
|
85 |
+
output = self.linear(output)
|
86 |
+
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
87 |
+
|
88 |
+
if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
|
89 |
+
dims=(1, 2))
|
90 |
+
return output
|
91 |
+
|
92 |
+
def relative_embedding(self):
|
93 |
+
cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
|
94 |
+
relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
|
95 |
+
# negative is allowed
|
96 |
+
return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
|
97 |
+
|
98 |
+
|
99 |
+
class Block(nn.Module):
|
100 |
+
def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
101 |
+
""" SwinTransformer Block
|
102 |
+
"""
|
103 |
+
super(Block, self).__init__()
|
104 |
+
self.input_dim = input_dim
|
105 |
+
self.output_dim = output_dim
|
106 |
+
assert type in ['W', 'SW']
|
107 |
+
self.type = type
|
108 |
+
if input_resolution <= window_size:
|
109 |
+
self.type = 'W'
|
110 |
+
|
111 |
+
self.ln1 = nn.LayerNorm(input_dim)
|
112 |
+
self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
|
113 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
114 |
+
self.ln2 = nn.LayerNorm(input_dim)
|
115 |
+
self.mlp = nn.Sequential(
|
116 |
+
nn.Linear(input_dim, 4 * input_dim),
|
117 |
+
nn.GELU(),
|
118 |
+
nn.Linear(4 * input_dim, output_dim),
|
119 |
+
)
|
120 |
+
|
121 |
+
def forward(self, x):
|
122 |
+
x = x + self.drop_path(self.msa(self.ln1(x)))
|
123 |
+
x = x + self.drop_path(self.mlp(self.ln2(x)))
|
124 |
+
return x
|
125 |
+
|
126 |
+
|
127 |
+
class ConvTransBlock(nn.Module):
|
128 |
+
def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
129 |
+
""" SwinTransformer and Conv Block
|
130 |
+
"""
|
131 |
+
super(ConvTransBlock, self).__init__()
|
132 |
+
self.conv_dim = conv_dim
|
133 |
+
self.trans_dim = trans_dim
|
134 |
+
self.head_dim = head_dim
|
135 |
+
self.window_size = window_size
|
136 |
+
self.drop_path = drop_path
|
137 |
+
self.type = type
|
138 |
+
self.input_resolution = input_resolution
|
139 |
+
|
140 |
+
assert self.type in ['W', 'SW']
|
141 |
+
if self.input_resolution <= self.window_size:
|
142 |
+
self.type = 'W'
|
143 |
+
|
144 |
+
self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
|
145 |
+
self.type, self.input_resolution)
|
146 |
+
self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
|
147 |
+
self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
|
148 |
+
|
149 |
+
self.conv_block = nn.Sequential(
|
150 |
+
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
|
151 |
+
nn.ReLU(True),
|
152 |
+
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
|
153 |
+
)
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
|
157 |
+
conv_x = self.conv_block(conv_x) + conv_x
|
158 |
+
trans_x = Rearrange('b c h w -> b h w c')(trans_x)
|
159 |
+
trans_x = self.trans_block(trans_x)
|
160 |
+
trans_x = Rearrange('b h w c -> b c h w')(trans_x)
|
161 |
+
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
|
162 |
+
x = x + res
|
163 |
+
|
164 |
+
return x
|
165 |
+
|
166 |
+
|
167 |
+
class SCUNet(nn.Module):
|
168 |
+
# def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
|
169 |
+
def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
|
170 |
+
super(SCUNet, self).__init__()
|
171 |
+
if config is None:
|
172 |
+
config = [2, 2, 2, 2, 2, 2, 2]
|
173 |
+
self.config = config
|
174 |
+
self.dim = dim
|
175 |
+
self.head_dim = 32
|
176 |
+
self.window_size = 8
|
177 |
+
|
178 |
+
# drop path rate for each layer
|
179 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
|
180 |
+
|
181 |
+
self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
|
182 |
+
|
183 |
+
begin = 0
|
184 |
+
self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
|
185 |
+
'W' if not i % 2 else 'SW', input_resolution)
|
186 |
+
for i in range(config[0])] + \
|
187 |
+
[nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
|
188 |
+
|
189 |
+
begin += config[0]
|
190 |
+
self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
|
191 |
+
'W' if not i % 2 else 'SW', input_resolution // 2)
|
192 |
+
for i in range(config[1])] + \
|
193 |
+
[nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
|
194 |
+
|
195 |
+
begin += config[1]
|
196 |
+
self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
197 |
+
'W' if not i % 2 else 'SW', input_resolution // 4)
|
198 |
+
for i in range(config[2])] + \
|
199 |
+
[nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
|
200 |
+
|
201 |
+
begin += config[2]
|
202 |
+
self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
203 |
+
'W' if not i % 2 else 'SW', input_resolution // 8)
|
204 |
+
for i in range(config[3])]
|
205 |
+
|
206 |
+
begin += config[3]
|
207 |
+
self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
|
208 |
+
[ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
209 |
+
'W' if not i % 2 else 'SW', input_resolution // 4)
|
210 |
+
for i in range(config[4])]
|
211 |
+
|
212 |
+
begin += config[4]
|
213 |
+
self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
|
214 |
+
[ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
|
215 |
+
'W' if not i % 2 else 'SW', input_resolution // 2)
|
216 |
+
for i in range(config[5])]
|
217 |
+
|
218 |
+
begin += config[5]
|
219 |
+
self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
|
220 |
+
[ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
|
221 |
+
'W' if not i % 2 else 'SW', input_resolution)
|
222 |
+
for i in range(config[6])]
|
223 |
+
|
224 |
+
self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
|
225 |
+
|
226 |
+
self.m_head = nn.Sequential(*self.m_head)
|
227 |
+
self.m_down1 = nn.Sequential(*self.m_down1)
|
228 |
+
self.m_down2 = nn.Sequential(*self.m_down2)
|
229 |
+
self.m_down3 = nn.Sequential(*self.m_down3)
|
230 |
+
self.m_body = nn.Sequential(*self.m_body)
|
231 |
+
self.m_up3 = nn.Sequential(*self.m_up3)
|
232 |
+
self.m_up2 = nn.Sequential(*self.m_up2)
|
233 |
+
self.m_up1 = nn.Sequential(*self.m_up1)
|
234 |
+
self.m_tail = nn.Sequential(*self.m_tail)
|
235 |
+
# self.apply(self._init_weights)
|
236 |
+
|
237 |
+
def forward(self, x0):
|
238 |
+
|
239 |
+
h, w = x0.size()[-2:]
|
240 |
+
paddingBottom = int(np.ceil(h / 64) * 64 - h)
|
241 |
+
paddingRight = int(np.ceil(w / 64) * 64 - w)
|
242 |
+
x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
|
243 |
+
|
244 |
+
x1 = self.m_head(x0)
|
245 |
+
x2 = self.m_down1(x1)
|
246 |
+
x3 = self.m_down2(x2)
|
247 |
+
x4 = self.m_down3(x3)
|
248 |
+
x = self.m_body(x4)
|
249 |
+
x = self.m_up3(x + x4)
|
250 |
+
x = self.m_up2(x + x3)
|
251 |
+
x = self.m_up1(x + x2)
|
252 |
+
x = self.m_tail(x + x1)
|
253 |
+
|
254 |
+
x = x[..., :h, :w]
|
255 |
+
|
256 |
+
return x
|
257 |
+
|
258 |
+
def _init_weights(self, m):
|
259 |
+
if isinstance(m, nn.Linear):
|
260 |
+
trunc_normal_(m.weight, std=.02)
|
261 |
+
if m.bias is not None:
|
262 |
+
nn.init.constant_(m.bias, 0)
|
263 |
+
elif isinstance(m, nn.LayerNorm):
|
264 |
+
nn.init.constant_(m.bias, 0)
|
265 |
+
nn.init.constant_(m.weight, 1.0)
|
modules/sd_hijack.py
ADDED
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import traceback
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from torch import einsum
|
8 |
+
from torch.nn.functional import silu
|
9 |
+
|
10 |
+
import modules.textual_inversion.textual_inversion
|
11 |
+
from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
|
12 |
+
from modules.shared import opts, device, cmd_opts
|
13 |
+
|
14 |
+
import ldm.modules.attention
|
15 |
+
import ldm.modules.diffusionmodules.model
|
16 |
+
|
17 |
+
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
18 |
+
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
19 |
+
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
20 |
+
|
21 |
+
|
22 |
+
def apply_optimizations():
|
23 |
+
undo_optimizations()
|
24 |
+
|
25 |
+
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
26 |
+
|
27 |
+
if cmd_opts.opt_split_attention_v1:
|
28 |
+
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
29 |
+
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
30 |
+
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
31 |
+
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
32 |
+
|
33 |
+
|
34 |
+
def undo_optimizations():
|
35 |
+
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
36 |
+
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
37 |
+
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
38 |
+
|
39 |
+
|
40 |
+
class StableDiffusionModelHijack:
|
41 |
+
fixes = None
|
42 |
+
comments = []
|
43 |
+
layers = None
|
44 |
+
circular_enabled = False
|
45 |
+
clip = None
|
46 |
+
|
47 |
+
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
48 |
+
|
49 |
+
def hijack(self, m):
|
50 |
+
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
51 |
+
|
52 |
+
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
53 |
+
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
54 |
+
|
55 |
+
self.clip = m.cond_stage_model
|
56 |
+
|
57 |
+
apply_optimizations()
|
58 |
+
|
59 |
+
def flatten(el):
|
60 |
+
flattened = [flatten(children) for children in el.children()]
|
61 |
+
res = [el]
|
62 |
+
for c in flattened:
|
63 |
+
res += c
|
64 |
+
return res
|
65 |
+
|
66 |
+
self.layers = flatten(m)
|
67 |
+
|
68 |
+
def undo_hijack(self, m):
|
69 |
+
if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
|
70 |
+
m.cond_stage_model = m.cond_stage_model.wrapped
|
71 |
+
|
72 |
+
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
73 |
+
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
|
74 |
+
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
|
75 |
+
|
76 |
+
def apply_circular(self, enable):
|
77 |
+
if self.circular_enabled == enable:
|
78 |
+
return
|
79 |
+
|
80 |
+
self.circular_enabled = enable
|
81 |
+
|
82 |
+
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
|
83 |
+
layer.padding_mode = 'circular' if enable else 'zeros'
|
84 |
+
|
85 |
+
def tokenize(self, text):
|
86 |
+
max_length = self.clip.max_length - 2
|
87 |
+
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
88 |
+
return remade_batch_tokens[0], token_count, max_length
|
89 |
+
|
90 |
+
|
91 |
+
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
92 |
+
def __init__(self, wrapped, hijack):
|
93 |
+
super().__init__()
|
94 |
+
self.wrapped = wrapped
|
95 |
+
self.hijack: StableDiffusionModelHijack = hijack
|
96 |
+
self.tokenizer = wrapped.tokenizer
|
97 |
+
self.max_length = wrapped.max_length
|
98 |
+
self.token_mults = {}
|
99 |
+
|
100 |
+
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
101 |
+
for text, ident in tokens_with_parens:
|
102 |
+
mult = 1.0
|
103 |
+
for c in text:
|
104 |
+
if c == '[':
|
105 |
+
mult /= 1.1
|
106 |
+
if c == ']':
|
107 |
+
mult *= 1.1
|
108 |
+
if c == '(':
|
109 |
+
mult *= 1.1
|
110 |
+
if c == ')':
|
111 |
+
mult /= 1.1
|
112 |
+
|
113 |
+
if mult != 1.0:
|
114 |
+
self.token_mults[ident] = mult
|
115 |
+
|
116 |
+
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
117 |
+
id_start = self.wrapped.tokenizer.bos_token_id
|
118 |
+
id_end = self.wrapped.tokenizer.eos_token_id
|
119 |
+
maxlen = self.wrapped.max_length
|
120 |
+
|
121 |
+
if opts.enable_emphasis:
|
122 |
+
parsed = prompt_parser.parse_prompt_attention(line)
|
123 |
+
else:
|
124 |
+
parsed = [[line, 1.0]]
|
125 |
+
|
126 |
+
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
|
127 |
+
|
128 |
+
fixes = []
|
129 |
+
remade_tokens = []
|
130 |
+
multipliers = []
|
131 |
+
|
132 |
+
for tokens, (text, weight) in zip(tokenized, parsed):
|
133 |
+
i = 0
|
134 |
+
while i < len(tokens):
|
135 |
+
token = tokens[i]
|
136 |
+
|
137 |
+
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
138 |
+
|
139 |
+
if embedding is None:
|
140 |
+
remade_tokens.append(token)
|
141 |
+
multipliers.append(weight)
|
142 |
+
i += 1
|
143 |
+
else:
|
144 |
+
emb_len = int(embedding.vec.shape[0])
|
145 |
+
fixes.append((len(remade_tokens), embedding))
|
146 |
+
remade_tokens += [0] * emb_len
|
147 |
+
multipliers += [weight] * emb_len
|
148 |
+
used_custom_terms.append((embedding.name, embedding.checksum()))
|
149 |
+
i += embedding_length_in_tokens
|
150 |
+
|
151 |
+
if len(remade_tokens) > maxlen - 2:
|
152 |
+
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
153 |
+
ovf = remade_tokens[maxlen - 2:]
|
154 |
+
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
155 |
+
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
156 |
+
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
157 |
+
|
158 |
+
token_count = len(remade_tokens)
|
159 |
+
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
160 |
+
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
161 |
+
|
162 |
+
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
163 |
+
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
164 |
+
|
165 |
+
return remade_tokens, fixes, multipliers, token_count
|
166 |
+
|
167 |
+
def process_text(self, texts):
|
168 |
+
used_custom_terms = []
|
169 |
+
remade_batch_tokens = []
|
170 |
+
hijack_comments = []
|
171 |
+
hijack_fixes = []
|
172 |
+
token_count = 0
|
173 |
+
|
174 |
+
cache = {}
|
175 |
+
batch_multipliers = []
|
176 |
+
for line in texts:
|
177 |
+
if line in cache:
|
178 |
+
remade_tokens, fixes, multipliers = cache[line]
|
179 |
+
else:
|
180 |
+
remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
|
181 |
+
|
182 |
+
cache[line] = (remade_tokens, fixes, multipliers)
|
183 |
+
|
184 |
+
remade_batch_tokens.append(remade_tokens)
|
185 |
+
hijack_fixes.append(fixes)
|
186 |
+
batch_multipliers.append(multipliers)
|
187 |
+
|
188 |
+
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
189 |
+
|
190 |
+
|
191 |
+
def process_text_old(self, text):
|
192 |
+
id_start = self.wrapped.tokenizer.bos_token_id
|
193 |
+
id_end = self.wrapped.tokenizer.eos_token_id
|
194 |
+
maxlen = self.wrapped.max_length
|
195 |
+
used_custom_terms = []
|
196 |
+
remade_batch_tokens = []
|
197 |
+
overflowing_words = []
|
198 |
+
hijack_comments = []
|
199 |
+
hijack_fixes = []
|
200 |
+
token_count = 0
|
201 |
+
|
202 |
+
cache = {}
|
203 |
+
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
204 |
+
batch_multipliers = []
|
205 |
+
for tokens in batch_tokens:
|
206 |
+
tuple_tokens = tuple(tokens)
|
207 |
+
|
208 |
+
if tuple_tokens in cache:
|
209 |
+
remade_tokens, fixes, multipliers = cache[tuple_tokens]
|
210 |
+
else:
|
211 |
+
fixes = []
|
212 |
+
remade_tokens = []
|
213 |
+
multipliers = []
|
214 |
+
mult = 1.0
|
215 |
+
|
216 |
+
i = 0
|
217 |
+
while i < len(tokens):
|
218 |
+
token = tokens[i]
|
219 |
+
|
220 |
+
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
221 |
+
|
222 |
+
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
223 |
+
if mult_change is not None:
|
224 |
+
mult *= mult_change
|
225 |
+
i += 1
|
226 |
+
elif embedding is None:
|
227 |
+
remade_tokens.append(token)
|
228 |
+
multipliers.append(mult)
|
229 |
+
i += 1
|
230 |
+
else:
|
231 |
+
emb_len = int(embedding.vec.shape[0])
|
232 |
+
fixes.append((len(remade_tokens), embedding))
|
233 |
+
remade_tokens += [0] * emb_len
|
234 |
+
multipliers += [mult] * emb_len
|
235 |
+
used_custom_terms.append((embedding.name, embedding.checksum()))
|
236 |
+
i += embedding_length_in_tokens
|
237 |
+
|
238 |
+
if len(remade_tokens) > maxlen - 2:
|
239 |
+
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
240 |
+
ovf = remade_tokens[maxlen - 2:]
|
241 |
+
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
242 |
+
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
243 |
+
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
244 |
+
|
245 |
+
token_count = len(remade_tokens)
|
246 |
+
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
247 |
+
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
248 |
+
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
249 |
+
|
250 |
+
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
251 |
+
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
252 |
+
|
253 |
+
remade_batch_tokens.append(remade_tokens)
|
254 |
+
hijack_fixes.append(fixes)
|
255 |
+
batch_multipliers.append(multipliers)
|
256 |
+
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
257 |
+
|
258 |
+
def forward(self, text):
|
259 |
+
|
260 |
+
if opts.use_old_emphasis_implementation:
|
261 |
+
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
262 |
+
else:
|
263 |
+
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
264 |
+
|
265 |
+
self.hijack.fixes = hijack_fixes
|
266 |
+
self.hijack.comments = hijack_comments
|
267 |
+
|
268 |
+
if len(used_custom_terms) > 0:
|
269 |
+
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
270 |
+
|
271 |
+
tokens = torch.asarray(remade_batch_tokens).to(device)
|
272 |
+
outputs = self.wrapped.transformer(input_ids=tokens)
|
273 |
+
z = outputs.last_hidden_state
|
274 |
+
|
275 |
+
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
276 |
+
batch_multipliers = torch.asarray(batch_multipliers).to(device)
|
277 |
+
original_mean = z.mean()
|
278 |
+
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
279 |
+
new_mean = z.mean()
|
280 |
+
z *= original_mean / new_mean
|
281 |
+
|
282 |
+
return z
|
283 |
+
|
284 |
+
|
285 |
+
class EmbeddingsWithFixes(torch.nn.Module):
|
286 |
+
def __init__(self, wrapped, embeddings):
|
287 |
+
super().__init__()
|
288 |
+
self.wrapped = wrapped
|
289 |
+
self.embeddings = embeddings
|
290 |
+
|
291 |
+
def forward(self, input_ids):
|
292 |
+
batch_fixes = self.embeddings.fixes
|
293 |
+
self.embeddings.fixes = None
|
294 |
+
|
295 |
+
inputs_embeds = self.wrapped(input_ids)
|
296 |
+
|
297 |
+
if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
|
298 |
+
return inputs_embeds
|
299 |
+
|
300 |
+
vecs = []
|
301 |
+
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
302 |
+
for offset, embedding in fixes:
|
303 |
+
emb = embedding.vec
|
304 |
+
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
|
305 |
+
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
|
306 |
+
|
307 |
+
vecs.append(tensor)
|
308 |
+
|
309 |
+
return torch.stack(vecs)
|
310 |
+
|
311 |
+
|
312 |
+
def add_circular_option_to_conv_2d():
|
313 |
+
conv2d_constructor = torch.nn.Conv2d.__init__
|
314 |
+
|
315 |
+
def conv2d_constructor_circular(self, *args, **kwargs):
|
316 |
+
return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
|
317 |
+
|
318 |
+
torch.nn.Conv2d.__init__ = conv2d_constructor_circular
|
319 |
+
|
320 |
+
|
321 |
+
model_hijack = StableDiffusionModelHijack()
|
modules/sd_hijack_optimizations.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import einsum
|
4 |
+
|
5 |
+
from ldm.util import default
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
+
from modules import shared
|
9 |
+
|
10 |
+
|
11 |
+
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
12 |
+
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
13 |
+
h = self.heads
|
14 |
+
|
15 |
+
q = self.to_q(x)
|
16 |
+
context = default(context, x)
|
17 |
+
k = self.to_k(context)
|
18 |
+
v = self.to_v(context)
|
19 |
+
del context, x
|
20 |
+
|
21 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
22 |
+
|
23 |
+
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
24 |
+
for i in range(0, q.shape[0], 2):
|
25 |
+
end = i + 2
|
26 |
+
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
27 |
+
s1 *= self.scale
|
28 |
+
|
29 |
+
s2 = s1.softmax(dim=-1)
|
30 |
+
del s1
|
31 |
+
|
32 |
+
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
33 |
+
del s2
|
34 |
+
|
35 |
+
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
36 |
+
del r1
|
37 |
+
|
38 |
+
return self.to_out(r2)
|
39 |
+
|
40 |
+
|
41 |
+
# taken from https://github.com/Doggettx/stable-diffusion
|
42 |
+
def split_cross_attention_forward(self, x, context=None, mask=None):
|
43 |
+
h = self.heads
|
44 |
+
|
45 |
+
q_in = self.to_q(x)
|
46 |
+
context = default(context, x)
|
47 |
+
|
48 |
+
hypernetwork = shared.selected_hypernetwork()
|
49 |
+
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
50 |
+
|
51 |
+
if hypernetwork_layers is not None:
|
52 |
+
k_in = self.to_k(hypernetwork_layers[0](context))
|
53 |
+
v_in = self.to_v(hypernetwork_layers[1](context))
|
54 |
+
else:
|
55 |
+
k_in = self.to_k(context)
|
56 |
+
v_in = self.to_v(context)
|
57 |
+
|
58 |
+
k_in *= self.scale
|
59 |
+
|
60 |
+
del context, x
|
61 |
+
|
62 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
63 |
+
del q_in, k_in, v_in
|
64 |
+
|
65 |
+
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
66 |
+
|
67 |
+
stats = torch.cuda.memory_stats(q.device)
|
68 |
+
mem_active = stats['active_bytes.all.current']
|
69 |
+
mem_reserved = stats['reserved_bytes.all.current']
|
70 |
+
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
71 |
+
mem_free_torch = mem_reserved - mem_active
|
72 |
+
mem_free_total = mem_free_cuda + mem_free_torch
|
73 |
+
|
74 |
+
gb = 1024 ** 3
|
75 |
+
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
76 |
+
modifier = 3 if q.element_size() == 2 else 2.5
|
77 |
+
mem_required = tensor_size * modifier
|
78 |
+
steps = 1
|
79 |
+
|
80 |
+
if mem_required > mem_free_total:
|
81 |
+
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
82 |
+
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
83 |
+
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
84 |
+
|
85 |
+
if steps > 64:
|
86 |
+
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
87 |
+
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
88 |
+
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
89 |
+
|
90 |
+
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
91 |
+
for i in range(0, q.shape[1], slice_size):
|
92 |
+
end = i + slice_size
|
93 |
+
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
94 |
+
|
95 |
+
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
96 |
+
del s1
|
97 |
+
|
98 |
+
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
99 |
+
del s2
|
100 |
+
|
101 |
+
del q, k, v
|
102 |
+
|
103 |
+
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
104 |
+
del r1
|
105 |
+
|
106 |
+
return self.to_out(r2)
|
107 |
+
|
108 |
+
def cross_attention_attnblock_forward(self, x):
|
109 |
+
h_ = x
|
110 |
+
h_ = self.norm(h_)
|
111 |
+
q1 = self.q(h_)
|
112 |
+
k1 = self.k(h_)
|
113 |
+
v = self.v(h_)
|
114 |
+
|
115 |
+
# compute attention
|
116 |
+
b, c, h, w = q1.shape
|
117 |
+
|
118 |
+
q2 = q1.reshape(b, c, h*w)
|
119 |
+
del q1
|
120 |
+
|
121 |
+
q = q2.permute(0, 2, 1) # b,hw,c
|
122 |
+
del q2
|
123 |
+
|
124 |
+
k = k1.reshape(b, c, h*w) # b,c,hw
|
125 |
+
del k1
|
126 |
+
|
127 |
+
h_ = torch.zeros_like(k, device=q.device)
|
128 |
+
|
129 |
+
stats = torch.cuda.memory_stats(q.device)
|
130 |
+
mem_active = stats['active_bytes.all.current']
|
131 |
+
mem_reserved = stats['reserved_bytes.all.current']
|
132 |
+
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
133 |
+
mem_free_torch = mem_reserved - mem_active
|
134 |
+
mem_free_total = mem_free_cuda + mem_free_torch
|
135 |
+
|
136 |
+
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
137 |
+
mem_required = tensor_size * 2.5
|
138 |
+
steps = 1
|
139 |
+
|
140 |
+
if mem_required > mem_free_total:
|
141 |
+
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
142 |
+
|
143 |
+
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
144 |
+
for i in range(0, q.shape[1], slice_size):
|
145 |
+
end = i + slice_size
|
146 |
+
|
147 |
+
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
148 |
+
w2 = w1 * (int(c)**(-0.5))
|
149 |
+
del w1
|
150 |
+
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
|
151 |
+
del w2
|
152 |
+
|
153 |
+
# attend to values
|
154 |
+
v1 = v.reshape(b, c, h*w)
|
155 |
+
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
156 |
+
del w3
|
157 |
+
|
158 |
+
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
159 |
+
del v1, w4
|
160 |
+
|
161 |
+
h2 = h_.reshape(b, c, h, w)
|
162 |
+
del h_
|
163 |
+
|
164 |
+
h3 = self.proj_out(h2)
|
165 |
+
del h2
|
166 |
+
|
167 |
+
h3 += x
|
168 |
+
|
169 |
+
return h3
|
modules/sd_models.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os.path
|
3 |
+
import sys
|
4 |
+
from collections import namedtuple
|
5 |
+
import torch
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
|
8 |
+
|
9 |
+
from ldm.util import instantiate_from_config
|
10 |
+
|
11 |
+
from modules import shared, modelloader, devices
|
12 |
+
from modules.paths import models_path
|
13 |
+
|
14 |
+
model_dir = "Stable-diffusion"
|
15 |
+
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
16 |
+
|
17 |
+
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
18 |
+
checkpoints_list = {}
|
19 |
+
|
20 |
+
try:
|
21 |
+
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
22 |
+
|
23 |
+
from transformers import logging
|
24 |
+
|
25 |
+
logging.set_verbosity_error()
|
26 |
+
except Exception:
|
27 |
+
pass
|
28 |
+
|
29 |
+
|
30 |
+
def setup_model():
|
31 |
+
if not os.path.exists(model_path):
|
32 |
+
os.makedirs(model_path)
|
33 |
+
|
34 |
+
list_models()
|
35 |
+
|
36 |
+
|
37 |
+
def checkpoint_tiles():
|
38 |
+
return sorted([x.title for x in checkpoints_list.values()])
|
39 |
+
|
40 |
+
|
41 |
+
def list_models():
|
42 |
+
checkpoints_list.clear()
|
43 |
+
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
|
44 |
+
|
45 |
+
def modeltitle(path, shorthash):
|
46 |
+
abspath = os.path.abspath(path)
|
47 |
+
|
48 |
+
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
49 |
+
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
50 |
+
elif abspath.startswith(model_path):
|
51 |
+
name = abspath.replace(model_path, '')
|
52 |
+
else:
|
53 |
+
name = os.path.basename(path)
|
54 |
+
|
55 |
+
if name.startswith("\\") or name.startswith("/"):
|
56 |
+
name = name[1:]
|
57 |
+
|
58 |
+
shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
59 |
+
|
60 |
+
return f'{name} [{shorthash}]', shortname
|
61 |
+
|
62 |
+
cmd_ckpt = shared.cmd_opts.ckpt
|
63 |
+
if os.path.exists(cmd_ckpt):
|
64 |
+
h = model_hash(cmd_ckpt)
|
65 |
+
title, short_model_name = modeltitle(cmd_ckpt, h)
|
66 |
+
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
|
67 |
+
shared.opts.data['sd_model_checkpoint'] = title
|
68 |
+
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
69 |
+
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
70 |
+
for filename in model_list:
|
71 |
+
h = model_hash(filename)
|
72 |
+
title, short_model_name = modeltitle(filename, h)
|
73 |
+
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
|
74 |
+
|
75 |
+
|
76 |
+
def get_closet_checkpoint_match(searchString):
|
77 |
+
applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
|
78 |
+
if len(applicable) > 0:
|
79 |
+
return applicable[0]
|
80 |
+
return None
|
81 |
+
|
82 |
+
|
83 |
+
def model_hash(filename):
|
84 |
+
try:
|
85 |
+
with open(filename, "rb") as file:
|
86 |
+
import hashlib
|
87 |
+
m = hashlib.sha256()
|
88 |
+
|
89 |
+
file.seek(0x100000)
|
90 |
+
m.update(file.read(0x10000))
|
91 |
+
return m.hexdigest()[0:8]
|
92 |
+
except FileNotFoundError:
|
93 |
+
return 'NOFILE'
|
94 |
+
|
95 |
+
|
96 |
+
def select_checkpoint():
|
97 |
+
model_checkpoint = shared.opts.sd_model_checkpoint
|
98 |
+
checkpoint_info = checkpoints_list.get(model_checkpoint, None)
|
99 |
+
if checkpoint_info is not None:
|
100 |
+
return checkpoint_info
|
101 |
+
|
102 |
+
if len(checkpoints_list) == 0:
|
103 |
+
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
|
104 |
+
if shared.cmd_opts.ckpt is not None:
|
105 |
+
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
|
106 |
+
print(f" - directory {model_path}", file=sys.stderr)
|
107 |
+
if shared.cmd_opts.ckpt_dir is not None:
|
108 |
+
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
|
109 |
+
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
|
110 |
+
exit(1)
|
111 |
+
|
112 |
+
checkpoint_info = next(iter(checkpoints_list.values()))
|
113 |
+
if model_checkpoint is not None:
|
114 |
+
print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)
|
115 |
+
|
116 |
+
return checkpoint_info
|
117 |
+
|
118 |
+
|
119 |
+
def load_model_weights(model, checkpoint_file, sd_model_hash):
|
120 |
+
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
121 |
+
|
122 |
+
pl_sd = torch.load(checkpoint_file, map_location="cpu")
|
123 |
+
if "global_step" in pl_sd:
|
124 |
+
print(f"Global Step: {pl_sd['global_step']}")
|
125 |
+
sd = pl_sd["state_dict"]
|
126 |
+
|
127 |
+
model.load_state_dict(sd, strict=False)
|
128 |
+
|
129 |
+
if shared.cmd_opts.opt_channelslast:
|
130 |
+
model.to(memory_format=torch.channels_last)
|
131 |
+
|
132 |
+
if not shared.cmd_opts.no_half:
|
133 |
+
model.half()
|
134 |
+
|
135 |
+
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
136 |
+
|
137 |
+
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
|
138 |
+
if os.path.exists(vae_file):
|
139 |
+
print(f"Loading VAE weights from: {vae_file}")
|
140 |
+
vae_ckpt = torch.load(vae_file, map_location="cpu")
|
141 |
+
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
142 |
+
|
143 |
+
model.first_stage_model.load_state_dict(vae_dict)
|
144 |
+
|
145 |
+
model.sd_model_hash = sd_model_hash
|
146 |
+
model.sd_model_checkpint = checkpoint_file
|
147 |
+
|
148 |
+
|
149 |
+
def load_model():
|
150 |
+
from modules import lowvram, sd_hijack
|
151 |
+
checkpoint_info = select_checkpoint()
|
152 |
+
|
153 |
+
sd_config = OmegaConf.load(shared.cmd_opts.config)
|
154 |
+
sd_model = instantiate_from_config(sd_config.model)
|
155 |
+
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
|
156 |
+
|
157 |
+
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
158 |
+
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
159 |
+
else:
|
160 |
+
sd_model.to(shared.device)
|
161 |
+
|
162 |
+
sd_hijack.model_hijack.hijack(sd_model)
|
163 |
+
|
164 |
+
sd_model.eval()
|
165 |
+
|
166 |
+
print(f"Model loaded.")
|
167 |
+
return sd_model
|
168 |
+
|
169 |
+
|
170 |
+
def reload_model_weights(sd_model, info=None):
|
171 |
+
from modules import lowvram, devices, sd_hijack
|
172 |
+
checkpoint_info = info or select_checkpoint()
|
173 |
+
|
174 |
+
if sd_model.sd_model_checkpint == checkpoint_info.filename:
|
175 |
+
return
|
176 |
+
|
177 |
+
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
178 |
+
lowvram.send_everything_to_cpu()
|
179 |
+
else:
|
180 |
+
sd_model.to(devices.cpu)
|
181 |
+
|
182 |
+
sd_hijack.model_hijack.undo_hijack(sd_model)
|
183 |
+
|
184 |
+
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
|
185 |
+
|
186 |
+
sd_hijack.model_hijack.hijack(sd_model)
|
187 |
+
|
188 |
+
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
189 |
+
sd_model.to(devices.device)
|
190 |
+
|
191 |
+
print(f"Weights loaded.")
|
192 |
+
return sd_model
|
modules/sd_samplers.py
ADDED
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import tqdm
|
5 |
+
from PIL import Image
|
6 |
+
import inspect
|
7 |
+
import k_diffusion.sampling
|
8 |
+
import ldm.models.diffusion.ddim
|
9 |
+
import ldm.models.diffusion.plms
|
10 |
+
from modules import prompt_parser
|
11 |
+
|
12 |
+
from modules.shared import opts, cmd_opts, state
|
13 |
+
import modules.shared as shared
|
14 |
+
|
15 |
+
|
16 |
+
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
17 |
+
|
18 |
+
samplers_k_diffusion = [
|
19 |
+
('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}),
|
20 |
+
('Euler', 'sample_euler', ['k_euler'], {}),
|
21 |
+
('LMS', 'sample_lms', ['k_lms'], {}),
|
22 |
+
('Heun', 'sample_heun', ['k_heun'], {}),
|
23 |
+
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
|
24 |
+
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
|
25 |
+
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
|
26 |
+
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
|
27 |
+
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
28 |
+
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
|
29 |
+
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
|
30 |
+
]
|
31 |
+
|
32 |
+
samplers_data_k_diffusion = [
|
33 |
+
SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
|
34 |
+
for label, funcname, aliases, options in samplers_k_diffusion
|
35 |
+
if hasattr(k_diffusion.sampling, funcname)
|
36 |
+
]
|
37 |
+
|
38 |
+
all_samplers = [
|
39 |
+
*samplers_data_k_diffusion,
|
40 |
+
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
41 |
+
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
42 |
+
]
|
43 |
+
|
44 |
+
samplers = []
|
45 |
+
samplers_for_img2img = []
|
46 |
+
|
47 |
+
|
48 |
+
def create_sampler_with_index(list_of_configs, index, model):
|
49 |
+
config = list_of_configs[index]
|
50 |
+
sampler = config.constructor(model)
|
51 |
+
sampler.config = config
|
52 |
+
|
53 |
+
return sampler
|
54 |
+
|
55 |
+
|
56 |
+
def set_samplers():
|
57 |
+
global samplers, samplers_for_img2img
|
58 |
+
|
59 |
+
hidden = set(opts.hide_samplers)
|
60 |
+
hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive'])
|
61 |
+
|
62 |
+
samplers = [x for x in all_samplers if x.name not in hidden]
|
63 |
+
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
64 |
+
|
65 |
+
|
66 |
+
set_samplers()
|
67 |
+
|
68 |
+
sampler_extra_params = {
|
69 |
+
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
70 |
+
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
71 |
+
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
72 |
+
}
|
73 |
+
|
74 |
+
def setup_img2img_steps(p, steps=None):
|
75 |
+
if opts.img2img_fix_steps or steps is not None:
|
76 |
+
steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
|
77 |
+
t_enc = p.steps - 1
|
78 |
+
else:
|
79 |
+
steps = p.steps
|
80 |
+
t_enc = int(min(p.denoising_strength, 0.999) * steps)
|
81 |
+
|
82 |
+
return steps, t_enc
|
83 |
+
|
84 |
+
|
85 |
+
def sample_to_image(samples):
|
86 |
+
x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0]
|
87 |
+
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
88 |
+
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
89 |
+
x_sample = x_sample.astype(np.uint8)
|
90 |
+
return Image.fromarray(x_sample)
|
91 |
+
|
92 |
+
|
93 |
+
def store_latent(decoded):
|
94 |
+
state.current_latent = decoded
|
95 |
+
|
96 |
+
if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
|
97 |
+
if not shared.parallel_processing_allowed:
|
98 |
+
shared.state.current_image = sample_to_image(decoded)
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
def extended_tdqm(sequence, *args, desc=None, **kwargs):
|
103 |
+
state.sampling_steps = len(sequence)
|
104 |
+
state.sampling_step = 0
|
105 |
+
|
106 |
+
seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
|
107 |
+
|
108 |
+
for x in seq:
|
109 |
+
if state.interrupted:
|
110 |
+
break
|
111 |
+
|
112 |
+
yield x
|
113 |
+
|
114 |
+
state.sampling_step += 1
|
115 |
+
shared.total_tqdm.update()
|
116 |
+
|
117 |
+
|
118 |
+
ldm.models.diffusion.ddim.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
|
119 |
+
ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
|
120 |
+
|
121 |
+
|
122 |
+
class VanillaStableDiffusionSampler:
|
123 |
+
def __init__(self, constructor, sd_model):
|
124 |
+
self.sampler = constructor(sd_model)
|
125 |
+
self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
|
126 |
+
self.mask = None
|
127 |
+
self.nmask = None
|
128 |
+
self.init_latent = None
|
129 |
+
self.sampler_noises = None
|
130 |
+
self.step = 0
|
131 |
+
self.eta = None
|
132 |
+
self.default_eta = 0.0
|
133 |
+
self.config = None
|
134 |
+
|
135 |
+
def number_of_needed_noises(self, p):
|
136 |
+
return 0
|
137 |
+
|
138 |
+
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
139 |
+
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
140 |
+
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
141 |
+
|
142 |
+
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
|
143 |
+
cond = tensor
|
144 |
+
|
145 |
+
if self.mask is not None:
|
146 |
+
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
147 |
+
x_dec = img_orig * self.mask + self.nmask * x_dec
|
148 |
+
|
149 |
+
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
150 |
+
|
151 |
+
if self.mask is not None:
|
152 |
+
store_latent(self.init_latent * self.mask + self.nmask * res[1])
|
153 |
+
else:
|
154 |
+
store_latent(res[1])
|
155 |
+
|
156 |
+
self.step += 1
|
157 |
+
return res
|
158 |
+
|
159 |
+
def initialize(self, p):
|
160 |
+
self.eta = p.eta if p.eta is not None else opts.eta_ddim
|
161 |
+
|
162 |
+
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
163 |
+
if hasattr(self.sampler, fieldname):
|
164 |
+
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
165 |
+
|
166 |
+
self.mask = p.mask if hasattr(p, 'mask') else None
|
167 |
+
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
168 |
+
|
169 |
+
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
|
170 |
+
steps, t_enc = setup_img2img_steps(p, steps)
|
171 |
+
|
172 |
+
self.initialize(p)
|
173 |
+
|
174 |
+
# existing code fails with cetain step counts, like 9
|
175 |
+
try:
|
176 |
+
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
177 |
+
except Exception:
|
178 |
+
self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
179 |
+
|
180 |
+
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
181 |
+
|
182 |
+
self.init_latent = x
|
183 |
+
self.step = 0
|
184 |
+
|
185 |
+
samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)
|
186 |
+
|
187 |
+
return samples
|
188 |
+
|
189 |
+
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
190 |
+
self.initialize(p)
|
191 |
+
|
192 |
+
self.init_latent = None
|
193 |
+
self.step = 0
|
194 |
+
|
195 |
+
steps = steps or p.steps
|
196 |
+
|
197 |
+
# existing code fails with cetin step counts, like 9
|
198 |
+
try:
|
199 |
+
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
|
200 |
+
except Exception:
|
201 |
+
samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
|
202 |
+
|
203 |
+
return samples_ddim
|
204 |
+
|
205 |
+
|
206 |
+
class CFGDenoiser(torch.nn.Module):
|
207 |
+
def __init__(self, model):
|
208 |
+
super().__init__()
|
209 |
+
self.inner_model = model
|
210 |
+
self.mask = None
|
211 |
+
self.nmask = None
|
212 |
+
self.init_latent = None
|
213 |
+
self.step = 0
|
214 |
+
|
215 |
+
def forward(self, x, sigma, uncond, cond, cond_scale):
|
216 |
+
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
217 |
+
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
218 |
+
|
219 |
+
batch_size = len(conds_list)
|
220 |
+
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
221 |
+
|
222 |
+
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
223 |
+
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
224 |
+
cond_in = torch.cat([tensor, uncond])
|
225 |
+
|
226 |
+
if shared.batch_cond_uncond:
|
227 |
+
x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
|
228 |
+
else:
|
229 |
+
x_out = torch.zeros_like(x_in)
|
230 |
+
for batch_offset in range(0, x_out.shape[0], batch_size):
|
231 |
+
a = batch_offset
|
232 |
+
b = a + batch_size
|
233 |
+
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
|
234 |
+
|
235 |
+
denoised_uncond = x_out[-batch_size:]
|
236 |
+
denoised = torch.clone(denoised_uncond)
|
237 |
+
|
238 |
+
for i, conds in enumerate(conds_list):
|
239 |
+
for cond_index, weight in conds:
|
240 |
+
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
241 |
+
|
242 |
+
if self.mask is not None:
|
243 |
+
denoised = self.init_latent * self.mask + self.nmask * denoised
|
244 |
+
|
245 |
+
self.step += 1
|
246 |
+
|
247 |
+
return denoised
|
248 |
+
|
249 |
+
|
250 |
+
def extended_trange(sampler, count, *args, **kwargs):
|
251 |
+
state.sampling_steps = count
|
252 |
+
state.sampling_step = 0
|
253 |
+
|
254 |
+
seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
|
255 |
+
|
256 |
+
for x in seq:
|
257 |
+
if state.interrupted:
|
258 |
+
break
|
259 |
+
|
260 |
+
if sampler.stop_at is not None and x > sampler.stop_at:
|
261 |
+
break
|
262 |
+
|
263 |
+
yield x
|
264 |
+
|
265 |
+
state.sampling_step += 1
|
266 |
+
shared.total_tqdm.update()
|
267 |
+
|
268 |
+
|
269 |
+
class TorchHijack:
|
270 |
+
def __init__(self, kdiff_sampler):
|
271 |
+
self.kdiff_sampler = kdiff_sampler
|
272 |
+
|
273 |
+
def __getattr__(self, item):
|
274 |
+
if item == 'randn_like':
|
275 |
+
return self.kdiff_sampler.randn_like
|
276 |
+
|
277 |
+
if hasattr(torch, item):
|
278 |
+
return getattr(torch, item)
|
279 |
+
|
280 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
281 |
+
|
282 |
+
|
283 |
+
class KDiffusionSampler:
|
284 |
+
def __init__(self, funcname, sd_model):
|
285 |
+
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
|
286 |
+
self.funcname = funcname
|
287 |
+
self.func = getattr(k_diffusion.sampling, self.funcname)
|
288 |
+
self.extra_params = sampler_extra_params.get(funcname, [])
|
289 |
+
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
290 |
+
self.sampler_noises = None
|
291 |
+
self.sampler_noise_index = 0
|
292 |
+
self.stop_at = None
|
293 |
+
self.eta = None
|
294 |
+
self.default_eta = 1.0
|
295 |
+
self.config = None
|
296 |
+
|
297 |
+
def callback_state(self, d):
|
298 |
+
store_latent(d["denoised"])
|
299 |
+
|
300 |
+
def number_of_needed_noises(self, p):
|
301 |
+
return p.steps
|
302 |
+
|
303 |
+
def randn_like(self, x):
|
304 |
+
noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
|
305 |
+
|
306 |
+
if noise is not None and x.shape == noise.shape:
|
307 |
+
res = noise
|
308 |
+
else:
|
309 |
+
res = torch.randn_like(x)
|
310 |
+
|
311 |
+
self.sampler_noise_index += 1
|
312 |
+
return res
|
313 |
+
|
314 |
+
def initialize(self, p):
|
315 |
+
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
316 |
+
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
317 |
+
self.model_wrap.step = 0
|
318 |
+
self.sampler_noise_index = 0
|
319 |
+
self.eta = p.eta or opts.eta_ancestral
|
320 |
+
|
321 |
+
if hasattr(k_diffusion.sampling, 'trange'):
|
322 |
+
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
|
323 |
+
|
324 |
+
if self.sampler_noises is not None:
|
325 |
+
k_diffusion.sampling.torch = TorchHijack(self)
|
326 |
+
|
327 |
+
extra_params_kwargs = {}
|
328 |
+
for param_name in self.extra_params:
|
329 |
+
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
330 |
+
extra_params_kwargs[param_name] = getattr(p, param_name)
|
331 |
+
|
332 |
+
if 'eta' in inspect.signature(self.func).parameters:
|
333 |
+
extra_params_kwargs['eta'] = self.eta
|
334 |
+
|
335 |
+
return extra_params_kwargs
|
336 |
+
|
337 |
+
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
|
338 |
+
steps, t_enc = setup_img2img_steps(p, steps)
|
339 |
+
|
340 |
+
if p.sampler_noise_scheduler_override:
|
341 |
+
sigmas = p.sampler_noise_scheduler_override(steps)
|
342 |
+
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
343 |
+
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
|
344 |
+
else:
|
345 |
+
sigmas = self.model_wrap.get_sigmas(steps)
|
346 |
+
|
347 |
+
noise = noise * sigmas[steps - t_enc - 1]
|
348 |
+
xi = x + noise
|
349 |
+
|
350 |
+
extra_params_kwargs = self.initialize(p)
|
351 |
+
|
352 |
+
sigma_sched = sigmas[steps - t_enc - 1:]
|
353 |
+
|
354 |
+
self.model_wrap_cfg.init_latent = x
|
355 |
+
|
356 |
+
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
357 |
+
|
358 |
+
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
359 |
+
steps = steps or p.steps
|
360 |
+
|
361 |
+
if p.sampler_noise_scheduler_override:
|
362 |
+
sigmas = p.sampler_noise_scheduler_override(steps)
|
363 |
+
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
364 |
+
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
|
365 |
+
else:
|
366 |
+
sigmas = self.model_wrap.get_sigmas(steps)
|
367 |
+
|
368 |
+
x = x * sigmas[0]
|
369 |
+
|
370 |
+
extra_params_kwargs = self.initialize(p)
|
371 |
+
if 'sigma_min' in inspect.signature(self.func).parameters:
|
372 |
+
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
373 |
+
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
374 |
+
if 'n' in inspect.signature(self.func).parameters:
|
375 |
+
extra_params_kwargs['n'] = steps
|
376 |
+
else:
|
377 |
+
extra_params_kwargs['sigmas'] = sigmas
|
378 |
+
samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
379 |
+
return samples
|
380 |
+
|
modules/shared.py
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import tqdm
|
9 |
+
|
10 |
+
import modules.artists
|
11 |
+
import modules.interrogate
|
12 |
+
import modules.memmon
|
13 |
+
import modules.sd_models
|
14 |
+
import modules.styles
|
15 |
+
import modules.devices as devices
|
16 |
+
from modules import sd_samplers, hypernetwork
|
17 |
+
from modules.paths import models_path, script_path, sd_path
|
18 |
+
|
19 |
+
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
20 |
+
default_sd_model_file = sd_model_file
|
21 |
+
parser = argparse.ArgumentParser()
|
22 |
+
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
|
23 |
+
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
24 |
+
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
25 |
+
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
26 |
+
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
27 |
+
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
28 |
+
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
29 |
+
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
30 |
+
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
31 |
+
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
32 |
+
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
33 |
+
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
34 |
+
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
35 |
+
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
36 |
+
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
37 |
+
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
38 |
+
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
39 |
+
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
40 |
+
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
41 |
+
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
42 |
+
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
43 |
+
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
|
44 |
+
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
|
45 |
+
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
|
46 |
+
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
|
47 |
+
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
48 |
+
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
49 |
+
parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
|
50 |
+
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
51 |
+
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
52 |
+
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
53 |
+
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
|
54 |
+
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
55 |
+
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
|
56 |
+
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
57 |
+
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
58 |
+
parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
|
59 |
+
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
60 |
+
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
61 |
+
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
62 |
+
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
63 |
+
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
64 |
+
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
65 |
+
|
66 |
+
|
67 |
+
cmd_opts = parser.parse_args()
|
68 |
+
|
69 |
+
devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
|
70 |
+
(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
|
71 |
+
|
72 |
+
device = devices.device
|
73 |
+
|
74 |
+
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
75 |
+
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
76 |
+
|
77 |
+
config_filename = cmd_opts.ui_settings_file
|
78 |
+
|
79 |
+
hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks'))
|
80 |
+
|
81 |
+
|
82 |
+
def selected_hypernetwork():
|
83 |
+
return hypernetworks.get(opts.sd_hypernetwork, None)
|
84 |
+
|
85 |
+
|
86 |
+
class State:
|
87 |
+
interrupted = False
|
88 |
+
job = ""
|
89 |
+
job_no = 0
|
90 |
+
job_count = 0
|
91 |
+
job_timestamp = '0'
|
92 |
+
sampling_step = 0
|
93 |
+
sampling_steps = 0
|
94 |
+
current_latent = None
|
95 |
+
current_image = None
|
96 |
+
current_image_sampling_step = 0
|
97 |
+
textinfo = None
|
98 |
+
|
99 |
+
def interrupt(self):
|
100 |
+
self.interrupted = True
|
101 |
+
|
102 |
+
def nextjob(self):
|
103 |
+
self.job_no += 1
|
104 |
+
self.sampling_step = 0
|
105 |
+
self.current_image_sampling_step = 0
|
106 |
+
|
107 |
+
def get_job_timestamp(self):
|
108 |
+
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
|
109 |
+
|
110 |
+
|
111 |
+
state = State()
|
112 |
+
|
113 |
+
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
|
114 |
+
|
115 |
+
styles_filename = cmd_opts.styles_file
|
116 |
+
prompt_styles = modules.styles.StyleDatabase(styles_filename)
|
117 |
+
|
118 |
+
interrogator = modules.interrogate.InterrogateModels("interrogate")
|
119 |
+
|
120 |
+
face_restorers = []
|
121 |
+
# This was moved to webui.py with the other model "setup" calls.
|
122 |
+
# modules.sd_models.list_models()
|
123 |
+
|
124 |
+
|
125 |
+
def realesrgan_models_names():
|
126 |
+
import modules.realesrgan_model
|
127 |
+
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
|
128 |
+
|
129 |
+
|
130 |
+
class OptionInfo:
|
131 |
+
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None):
|
132 |
+
self.default = default
|
133 |
+
self.label = label
|
134 |
+
self.component = component
|
135 |
+
self.component_args = component_args
|
136 |
+
self.onchange = onchange
|
137 |
+
self.section = None
|
138 |
+
|
139 |
+
|
140 |
+
def options_section(section_identifer, options_dict):
|
141 |
+
for k, v in options_dict.items():
|
142 |
+
v.section = section_identifer
|
143 |
+
|
144 |
+
return options_dict
|
145 |
+
|
146 |
+
|
147 |
+
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
148 |
+
|
149 |
+
options_templates = {}
|
150 |
+
|
151 |
+
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
152 |
+
"samples_save": OptionInfo(True, "Always save all generated images"),
|
153 |
+
"samples_format": OptionInfo('png', 'File format for images'),
|
154 |
+
"samples_filename_pattern": OptionInfo("", "Images filename pattern"),
|
155 |
+
|
156 |
+
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
157 |
+
"grid_format": OptionInfo('png', 'File format for grids'),
|
158 |
+
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
159 |
+
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
|
160 |
+
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
|
161 |
+
|
162 |
+
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
163 |
+
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
164 |
+
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
165 |
+
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
166 |
+
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
|
167 |
+
|
168 |
+
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
|
169 |
+
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
170 |
+
}))
|
171 |
+
|
172 |
+
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
173 |
+
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
|
174 |
+
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
|
175 |
+
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
|
176 |
+
"outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
|
177 |
+
"outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
|
178 |
+
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
|
179 |
+
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
|
180 |
+
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
|
181 |
+
}))
|
182 |
+
|
183 |
+
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
184 |
+
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
|
185 |
+
"grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
|
186 |
+
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
187 |
+
"directories_filename_pattern": OptionInfo("", "Directory name pattern"),
|
188 |
+
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
|
189 |
+
}))
|
190 |
+
|
191 |
+
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
192 |
+
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
193 |
+
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
194 |
+
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
195 |
+
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
|
196 |
+
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
197 |
+
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
|
198 |
+
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
199 |
+
}))
|
200 |
+
|
201 |
+
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
202 |
+
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
203 |
+
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
204 |
+
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
205 |
+
}))
|
206 |
+
|
207 |
+
options_templates.update(options_section(('system', "System"), {
|
208 |
+
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
|
209 |
+
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
210 |
+
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
211 |
+
}))
|
212 |
+
|
213 |
+
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
214 |
+
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
|
215 |
+
"sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
|
216 |
+
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
217 |
+
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
218 |
+
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
219 |
+
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
220 |
+
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
221 |
+
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
222 |
+
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
223 |
+
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
224 |
+
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
225 |
+
}))
|
226 |
+
|
227 |
+
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
228 |
+
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
|
229 |
+
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
|
230 |
+
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
231 |
+
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
232 |
+
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
233 |
+
"interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
|
234 |
+
}))
|
235 |
+
|
236 |
+
options_templates.update(options_section(('ui', "User interface"), {
|
237 |
+
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
238 |
+
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
|
239 |
+
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
240 |
+
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
241 |
+
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
242 |
+
"font": OptionInfo("", "Font for image grids that have text"),
|
243 |
+
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
244 |
+
"js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
245 |
+
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
246 |
+
}))
|
247 |
+
|
248 |
+
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
249 |
+
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
|
250 |
+
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
251 |
+
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
252 |
+
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
253 |
+
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
254 |
+
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
255 |
+
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
256 |
+
}))
|
257 |
+
|
258 |
+
|
259 |
+
class Options:
|
260 |
+
data = None
|
261 |
+
data_labels = options_templates
|
262 |
+
typemap = {int: float}
|
263 |
+
|
264 |
+
def __init__(self):
|
265 |
+
self.data = {k: v.default for k, v in self.data_labels.items()}
|
266 |
+
|
267 |
+
def __setattr__(self, key, value):
|
268 |
+
if self.data is not None:
|
269 |
+
if key in self.data:
|
270 |
+
self.data[key] = value
|
271 |
+
|
272 |
+
return super(Options, self).__setattr__(key, value)
|
273 |
+
|
274 |
+
def __getattr__(self, item):
|
275 |
+
if self.data is not None:
|
276 |
+
if item in self.data:
|
277 |
+
return self.data[item]
|
278 |
+
|
279 |
+
if item in self.data_labels:
|
280 |
+
return self.data_labels[item].default
|
281 |
+
|
282 |
+
return super(Options, self).__getattribute__(item)
|
283 |
+
|
284 |
+
def save(self, filename):
|
285 |
+
with open(filename, "w", encoding="utf8") as file:
|
286 |
+
json.dump(self.data, file)
|
287 |
+
|
288 |
+
def same_type(self, x, y):
|
289 |
+
if x is None or y is None:
|
290 |
+
return True
|
291 |
+
|
292 |
+
type_x = self.typemap.get(type(x), type(x))
|
293 |
+
type_y = self.typemap.get(type(y), type(y))
|
294 |
+
|
295 |
+
return type_x == type_y
|
296 |
+
|
297 |
+
def load(self, filename):
|
298 |
+
with open(filename, "r", encoding="utf8") as file:
|
299 |
+
self.data = json.load(file)
|
300 |
+
|
301 |
+
bad_settings = 0
|
302 |
+
for k, v in self.data.items():
|
303 |
+
info = self.data_labels.get(k, None)
|
304 |
+
if info is not None and not self.same_type(info.default, v):
|
305 |
+
print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
|
306 |
+
bad_settings += 1
|
307 |
+
|
308 |
+
if bad_settings > 0:
|
309 |
+
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
|
310 |
+
|
311 |
+
def onchange(self, key, func):
|
312 |
+
item = self.data_labels.get(key)
|
313 |
+
item.onchange = func
|
314 |
+
|
315 |
+
def dumpjson(self):
|
316 |
+
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
|
317 |
+
return json.dumps(d)
|
318 |
+
|
319 |
+
|
320 |
+
opts = Options()
|
321 |
+
if os.path.exists(config_filename):
|
322 |
+
opts.load(config_filename)
|
323 |
+
|
324 |
+
sd_upscalers = []
|
325 |
+
|
326 |
+
sd_model = None
|
327 |
+
|
328 |
+
progress_print_out = sys.stdout
|
329 |
+
|
330 |
+
|
331 |
+
class TotalTQDM:
|
332 |
+
def __init__(self):
|
333 |
+
self._tqdm = None
|
334 |
+
|
335 |
+
def reset(self):
|
336 |
+
self._tqdm = tqdm.tqdm(
|
337 |
+
desc="Total progress",
|
338 |
+
total=state.job_count * state.sampling_steps,
|
339 |
+
position=1,
|
340 |
+
file=progress_print_out
|
341 |
+
)
|
342 |
+
|
343 |
+
def update(self):
|
344 |
+
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
|
345 |
+
return
|
346 |
+
if self._tqdm is None:
|
347 |
+
self.reset()
|
348 |
+
self._tqdm.update()
|
349 |
+
|
350 |
+
def updateTotal(self, new_total):
|
351 |
+
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
|
352 |
+
return
|
353 |
+
if self._tqdm is None:
|
354 |
+
self.reset()
|
355 |
+
self._tqdm.total=new_total
|
356 |
+
|
357 |
+
def clear(self):
|
358 |
+
if self._tqdm is not None:
|
359 |
+
self._tqdm.close()
|
360 |
+
self._tqdm = None
|
361 |
+
|
362 |
+
|
363 |
+
total_tqdm = TotalTQDM()
|
364 |
+
|
365 |
+
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
|
366 |
+
mem_mon.start()
|
modules/styles.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime
|
2 |
+
from __future__ import annotations
|
3 |
+
|
4 |
+
import csv
|
5 |
+
import os
|
6 |
+
import os.path
|
7 |
+
import typing
|
8 |
+
import collections.abc as abc
|
9 |
+
import tempfile
|
10 |
+
import shutil
|
11 |
+
|
12 |
+
if typing.TYPE_CHECKING:
|
13 |
+
# Only import this when code is being type-checked, it doesn't have any effect at runtime
|
14 |
+
from .processing import StableDiffusionProcessing
|
15 |
+
|
16 |
+
|
17 |
+
class PromptStyle(typing.NamedTuple):
|
18 |
+
name: str
|
19 |
+
prompt: str
|
20 |
+
negative_prompt: str
|
21 |
+
|
22 |
+
|
23 |
+
def merge_prompts(style_prompt: str, prompt: str) -> str:
|
24 |
+
if "{prompt}" in style_prompt:
|
25 |
+
res = style_prompt.replace("{prompt}", prompt)
|
26 |
+
else:
|
27 |
+
parts = filter(None, (prompt.strip(), style_prompt.strip()))
|
28 |
+
res = ", ".join(parts)
|
29 |
+
|
30 |
+
return res
|
31 |
+
|
32 |
+
|
33 |
+
def apply_styles_to_prompt(prompt, styles):
|
34 |
+
for style in styles:
|
35 |
+
prompt = merge_prompts(style, prompt)
|
36 |
+
|
37 |
+
return prompt
|
38 |
+
|
39 |
+
|
40 |
+
class StyleDatabase:
|
41 |
+
def __init__(self, path: str):
|
42 |
+
self.no_style = PromptStyle("None", "", "")
|
43 |
+
self.styles = {"None": self.no_style}
|
44 |
+
|
45 |
+
if not os.path.exists(path):
|
46 |
+
return
|
47 |
+
|
48 |
+
with open(path, "r", encoding="utf8", newline='') as file:
|
49 |
+
reader = csv.DictReader(file)
|
50 |
+
for row in reader:
|
51 |
+
# Support loading old CSV format with "name, text"-columns
|
52 |
+
prompt = row["prompt"] if "prompt" in row else row["text"]
|
53 |
+
negative_prompt = row.get("negative_prompt", "")
|
54 |
+
self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
|
55 |
+
|
56 |
+
def get_style_prompts(self, styles):
|
57 |
+
return [self.styles.get(x, self.no_style).prompt for x in styles]
|
58 |
+
|
59 |
+
def get_negative_style_prompts(self, styles):
|
60 |
+
return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
|
61 |
+
|
62 |
+
def apply_styles_to_prompt(self, prompt, styles):
|
63 |
+
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
|
64 |
+
|
65 |
+
def apply_negative_styles_to_prompt(self, prompt, styles):
|
66 |
+
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
|
67 |
+
|
68 |
+
def apply_styles(self, p: StableDiffusionProcessing) -> None:
|
69 |
+
if isinstance(p.prompt, list):
|
70 |
+
p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
|
71 |
+
else:
|
72 |
+
p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)
|
73 |
+
|
74 |
+
if isinstance(p.negative_prompt, list):
|
75 |
+
p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
|
76 |
+
else:
|
77 |
+
p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
|
78 |
+
|
79 |
+
def save_styles(self, path: str) -> None:
|
80 |
+
# Write to temporary file first, so we don't nuke the file if something goes wrong
|
81 |
+
fd, temp_path = tempfile.mkstemp(".csv")
|
82 |
+
with os.fdopen(fd, "w", encoding="utf8", newline='') as file:
|
83 |
+
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
|
84 |
+
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
|
85 |
+
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
|
86 |
+
writer.writeheader()
|
87 |
+
writer.writerows(style._asdict() for k, style in self.styles.items())
|
88 |
+
|
89 |
+
# Always keep a backup file around
|
90 |
+
if os.path.exists(path):
|
91 |
+
shutil.move(path, path + ".bak")
|
92 |
+
shutil.move(temp_path, path)
|
modules/swinir_model.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
+
import os
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from basicsr.utils.download_util import load_file_from_url
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from modules import modelloader
|
11 |
+
from modules.paths import models_path
|
12 |
+
from modules.shared import cmd_opts, opts, device
|
13 |
+
from modules.swinir_model_arch import SwinIR as net
|
14 |
+
from modules.upscaler import Upscaler, UpscalerData
|
15 |
+
|
16 |
+
precision_scope = (
|
17 |
+
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
class UpscalerSwinIR(Upscaler):
|
22 |
+
def __init__(self, dirname):
|
23 |
+
self.name = "SwinIR"
|
24 |
+
self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
|
25 |
+
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
|
26 |
+
"-L_x4_GAN.pth "
|
27 |
+
self.model_name = "SwinIR 4x"
|
28 |
+
self.model_path = os.path.join(models_path, self.name)
|
29 |
+
self.user_path = dirname
|
30 |
+
super().__init__()
|
31 |
+
scalers = []
|
32 |
+
model_files = self.find_models(ext_filter=[".pt", ".pth"])
|
33 |
+
for model in model_files:
|
34 |
+
if "http" in model:
|
35 |
+
name = self.model_name
|
36 |
+
else:
|
37 |
+
name = modelloader.friendly_name(model)
|
38 |
+
model_data = UpscalerData(name, model, self)
|
39 |
+
scalers.append(model_data)
|
40 |
+
self.scalers = scalers
|
41 |
+
|
42 |
+
def do_upscale(self, img, model_file):
|
43 |
+
model = self.load_model(model_file)
|
44 |
+
if model is None:
|
45 |
+
return img
|
46 |
+
model = model.to(device)
|
47 |
+
img = upscale(img, model)
|
48 |
+
try:
|
49 |
+
torch.cuda.empty_cache()
|
50 |
+
except:
|
51 |
+
pass
|
52 |
+
return img
|
53 |
+
|
54 |
+
def load_model(self, path, scale=4):
|
55 |
+
if "http" in path:
|
56 |
+
dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
|
57 |
+
filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
|
58 |
+
else:
|
59 |
+
filename = path
|
60 |
+
if filename is None or not os.path.exists(filename):
|
61 |
+
return None
|
62 |
+
model = net(
|
63 |
+
upscale=scale,
|
64 |
+
in_chans=3,
|
65 |
+
img_size=64,
|
66 |
+
window_size=8,
|
67 |
+
img_range=1.0,
|
68 |
+
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
69 |
+
embed_dim=240,
|
70 |
+
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
71 |
+
mlp_ratio=2,
|
72 |
+
upsampler="nearest+conv",
|
73 |
+
resi_connection="3conv",
|
74 |
+
)
|
75 |
+
|
76 |
+
pretrained_model = torch.load(filename)
|
77 |
+
model.load_state_dict(pretrained_model["params_ema"], strict=True)
|
78 |
+
if not cmd_opts.no_half:
|
79 |
+
model = model.half()
|
80 |
+
return model
|
81 |
+
|
82 |
+
|
83 |
+
def upscale(
|
84 |
+
img,
|
85 |
+
model,
|
86 |
+
tile=opts.SWIN_tile,
|
87 |
+
tile_overlap=opts.SWIN_tile_overlap,
|
88 |
+
window_size=8,
|
89 |
+
scale=4,
|
90 |
+
):
|
91 |
+
img = np.array(img)
|
92 |
+
img = img[:, :, ::-1]
|
93 |
+
img = np.moveaxis(img, 2, 0) / 255
|
94 |
+
img = torch.from_numpy(img).float()
|
95 |
+
img = img.unsqueeze(0).to(device)
|
96 |
+
with torch.no_grad(), precision_scope("cuda"):
|
97 |
+
_, _, h_old, w_old = img.size()
|
98 |
+
h_pad = (h_old // window_size + 1) * window_size - h_old
|
99 |
+
w_pad = (w_old // window_size + 1) * window_size - w_old
|
100 |
+
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
|
101 |
+
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
|
102 |
+
output = inference(img, model, tile, tile_overlap, window_size, scale)
|
103 |
+
output = output[..., : h_old * scale, : w_old * scale]
|
104 |
+
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
105 |
+
if output.ndim == 3:
|
106 |
+
output = np.transpose(
|
107 |
+
output[[2, 1, 0], :, :], (1, 2, 0)
|
108 |
+
) # CHW-RGB to HCW-BGR
|
109 |
+
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
|
110 |
+
return Image.fromarray(output, "RGB")
|
111 |
+
|
112 |
+
|
113 |
+
def inference(img, model, tile, tile_overlap, window_size, scale):
|
114 |
+
# test the image tile by tile
|
115 |
+
b, c, h, w = img.size()
|
116 |
+
tile = min(tile, h, w)
|
117 |
+
assert tile % window_size == 0, "tile size should be a multiple of window_size"
|
118 |
+
sf = scale
|
119 |
+
|
120 |
+
stride = tile - tile_overlap
|
121 |
+
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
|
122 |
+
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
|
123 |
+
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
|
124 |
+
W = torch.zeros_like(E, dtype=torch.half, device=device)
|
125 |
+
|
126 |
+
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
|
127 |
+
for h_idx in h_idx_list:
|
128 |
+
for w_idx in w_idx_list:
|
129 |
+
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
130 |
+
out_patch = model(in_patch)
|
131 |
+
out_patch_mask = torch.ones_like(out_patch)
|
132 |
+
|
133 |
+
E[
|
134 |
+
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
135 |
+
].add_(out_patch)
|
136 |
+
W[
|
137 |
+
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
138 |
+
].add_(out_patch_mask)
|
139 |
+
pbar.update(1)
|
140 |
+
output = E.div_(W)
|
141 |
+
|
142 |
+
return output
|
modules/swinir_model_arch.py
ADDED
@@ -0,0 +1,867 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -----------------------------------------------------------------------------------
|
2 |
+
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
|
3 |
+
# Originally Written by Ze Liu, Modified by Jingyun Liang.
|
4 |
+
# -----------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.utils.checkpoint as checkpoint
|
11 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
12 |
+
|
13 |
+
|
14 |
+
class Mlp(nn.Module):
|
15 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
16 |
+
super().__init__()
|
17 |
+
out_features = out_features or in_features
|
18 |
+
hidden_features = hidden_features or in_features
|
19 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
20 |
+
self.act = act_layer()
|
21 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
22 |
+
self.drop = nn.Dropout(drop)
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
x = self.fc1(x)
|
26 |
+
x = self.act(x)
|
27 |
+
x = self.drop(x)
|
28 |
+
x = self.fc2(x)
|
29 |
+
x = self.drop(x)
|
30 |
+
return x
|
31 |
+
|
32 |
+
|
33 |
+
def window_partition(x, window_size):
|
34 |
+
"""
|
35 |
+
Args:
|
36 |
+
x: (B, H, W, C)
|
37 |
+
window_size (int): window size
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
windows: (num_windows*B, window_size, window_size, C)
|
41 |
+
"""
|
42 |
+
B, H, W, C = x.shape
|
43 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
44 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
45 |
+
return windows
|
46 |
+
|
47 |
+
|
48 |
+
def window_reverse(windows, window_size, H, W):
|
49 |
+
"""
|
50 |
+
Args:
|
51 |
+
windows: (num_windows*B, window_size, window_size, C)
|
52 |
+
window_size (int): Window size
|
53 |
+
H (int): Height of image
|
54 |
+
W (int): Width of image
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
x: (B, H, W, C)
|
58 |
+
"""
|
59 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
60 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
61 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
class WindowAttention(nn.Module):
|
66 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
67 |
+
It supports both of shifted and non-shifted window.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
dim (int): Number of input channels.
|
71 |
+
window_size (tuple[int]): The height and width of the window.
|
72 |
+
num_heads (int): Number of attention heads.
|
73 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
74 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
75 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
76 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
80 |
+
|
81 |
+
super().__init__()
|
82 |
+
self.dim = dim
|
83 |
+
self.window_size = window_size # Wh, Ww
|
84 |
+
self.num_heads = num_heads
|
85 |
+
head_dim = dim // num_heads
|
86 |
+
self.scale = qk_scale or head_dim ** -0.5
|
87 |
+
|
88 |
+
# define a parameter table of relative position bias
|
89 |
+
self.relative_position_bias_table = nn.Parameter(
|
90 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
91 |
+
|
92 |
+
# get pair-wise relative position index for each token inside the window
|
93 |
+
coords_h = torch.arange(self.window_size[0])
|
94 |
+
coords_w = torch.arange(self.window_size[1])
|
95 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
96 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
97 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
98 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
99 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
100 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
101 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
102 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
103 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
104 |
+
|
105 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
106 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
107 |
+
self.proj = nn.Linear(dim, dim)
|
108 |
+
|
109 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
110 |
+
|
111 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
112 |
+
self.softmax = nn.Softmax(dim=-1)
|
113 |
+
|
114 |
+
def forward(self, x, mask=None):
|
115 |
+
"""
|
116 |
+
Args:
|
117 |
+
x: input features with shape of (num_windows*B, N, C)
|
118 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
119 |
+
"""
|
120 |
+
B_, N, C = x.shape
|
121 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
122 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
123 |
+
|
124 |
+
q = q * self.scale
|
125 |
+
attn = (q @ k.transpose(-2, -1))
|
126 |
+
|
127 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
128 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
129 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
130 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
131 |
+
|
132 |
+
if mask is not None:
|
133 |
+
nW = mask.shape[0]
|
134 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
135 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
136 |
+
attn = self.softmax(attn)
|
137 |
+
else:
|
138 |
+
attn = self.softmax(attn)
|
139 |
+
|
140 |
+
attn = self.attn_drop(attn)
|
141 |
+
|
142 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
143 |
+
x = self.proj(x)
|
144 |
+
x = self.proj_drop(x)
|
145 |
+
return x
|
146 |
+
|
147 |
+
def extra_repr(self) -> str:
|
148 |
+
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
149 |
+
|
150 |
+
def flops(self, N):
|
151 |
+
# calculate flops for 1 window with token length of N
|
152 |
+
flops = 0
|
153 |
+
# qkv = self.qkv(x)
|
154 |
+
flops += N * self.dim * 3 * self.dim
|
155 |
+
# attn = (q @ k.transpose(-2, -1))
|
156 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
157 |
+
# x = (attn @ v)
|
158 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
159 |
+
# x = self.proj(x)
|
160 |
+
flops += N * self.dim * self.dim
|
161 |
+
return flops
|
162 |
+
|
163 |
+
|
164 |
+
class SwinTransformerBlock(nn.Module):
|
165 |
+
r""" Swin Transformer Block.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
dim (int): Number of input channels.
|
169 |
+
input_resolution (tuple[int]): Input resulotion.
|
170 |
+
num_heads (int): Number of attention heads.
|
171 |
+
window_size (int): Window size.
|
172 |
+
shift_size (int): Shift size for SW-MSA.
|
173 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
174 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
175 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
176 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
177 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
178 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
179 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
180 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
184 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
185 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
186 |
+
super().__init__()
|
187 |
+
self.dim = dim
|
188 |
+
self.input_resolution = input_resolution
|
189 |
+
self.num_heads = num_heads
|
190 |
+
self.window_size = window_size
|
191 |
+
self.shift_size = shift_size
|
192 |
+
self.mlp_ratio = mlp_ratio
|
193 |
+
if min(self.input_resolution) <= self.window_size:
|
194 |
+
# if window size is larger than input resolution, we don't partition windows
|
195 |
+
self.shift_size = 0
|
196 |
+
self.window_size = min(self.input_resolution)
|
197 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
198 |
+
|
199 |
+
self.norm1 = norm_layer(dim)
|
200 |
+
self.attn = WindowAttention(
|
201 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
202 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
203 |
+
|
204 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
205 |
+
self.norm2 = norm_layer(dim)
|
206 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
207 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
208 |
+
|
209 |
+
if self.shift_size > 0:
|
210 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
211 |
+
else:
|
212 |
+
attn_mask = None
|
213 |
+
|
214 |
+
self.register_buffer("attn_mask", attn_mask)
|
215 |
+
|
216 |
+
def calculate_mask(self, x_size):
|
217 |
+
# calculate attention mask for SW-MSA
|
218 |
+
H, W = x_size
|
219 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
220 |
+
h_slices = (slice(0, -self.window_size),
|
221 |
+
slice(-self.window_size, -self.shift_size),
|
222 |
+
slice(-self.shift_size, None))
|
223 |
+
w_slices = (slice(0, -self.window_size),
|
224 |
+
slice(-self.window_size, -self.shift_size),
|
225 |
+
slice(-self.shift_size, None))
|
226 |
+
cnt = 0
|
227 |
+
for h in h_slices:
|
228 |
+
for w in w_slices:
|
229 |
+
img_mask[:, h, w, :] = cnt
|
230 |
+
cnt += 1
|
231 |
+
|
232 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
233 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
234 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
235 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
236 |
+
|
237 |
+
return attn_mask
|
238 |
+
|
239 |
+
def forward(self, x, x_size):
|
240 |
+
H, W = x_size
|
241 |
+
B, L, C = x.shape
|
242 |
+
# assert L == H * W, "input feature has wrong size"
|
243 |
+
|
244 |
+
shortcut = x
|
245 |
+
x = self.norm1(x)
|
246 |
+
x = x.view(B, H, W, C)
|
247 |
+
|
248 |
+
# cyclic shift
|
249 |
+
if self.shift_size > 0:
|
250 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
251 |
+
else:
|
252 |
+
shifted_x = x
|
253 |
+
|
254 |
+
# partition windows
|
255 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
256 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
257 |
+
|
258 |
+
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
|
259 |
+
if self.input_resolution == x_size:
|
260 |
+
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
261 |
+
else:
|
262 |
+
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
263 |
+
|
264 |
+
# merge windows
|
265 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
266 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
267 |
+
|
268 |
+
# reverse cyclic shift
|
269 |
+
if self.shift_size > 0:
|
270 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
271 |
+
else:
|
272 |
+
x = shifted_x
|
273 |
+
x = x.view(B, H * W, C)
|
274 |
+
|
275 |
+
# FFN
|
276 |
+
x = shortcut + self.drop_path(x)
|
277 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
278 |
+
|
279 |
+
return x
|
280 |
+
|
281 |
+
def extra_repr(self) -> str:
|
282 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
283 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
284 |
+
|
285 |
+
def flops(self):
|
286 |
+
flops = 0
|
287 |
+
H, W = self.input_resolution
|
288 |
+
# norm1
|
289 |
+
flops += self.dim * H * W
|
290 |
+
# W-MSA/SW-MSA
|
291 |
+
nW = H * W / self.window_size / self.window_size
|
292 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
293 |
+
# mlp
|
294 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
295 |
+
# norm2
|
296 |
+
flops += self.dim * H * W
|
297 |
+
return flops
|
298 |
+
|
299 |
+
|
300 |
+
class PatchMerging(nn.Module):
|
301 |
+
r""" Patch Merging Layer.
|
302 |
+
|
303 |
+
Args:
|
304 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
305 |
+
dim (int): Number of input channels.
|
306 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
307 |
+
"""
|
308 |
+
|
309 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
310 |
+
super().__init__()
|
311 |
+
self.input_resolution = input_resolution
|
312 |
+
self.dim = dim
|
313 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
314 |
+
self.norm = norm_layer(4 * dim)
|
315 |
+
|
316 |
+
def forward(self, x):
|
317 |
+
"""
|
318 |
+
x: B, H*W, C
|
319 |
+
"""
|
320 |
+
H, W = self.input_resolution
|
321 |
+
B, L, C = x.shape
|
322 |
+
assert L == H * W, "input feature has wrong size"
|
323 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
324 |
+
|
325 |
+
x = x.view(B, H, W, C)
|
326 |
+
|
327 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
328 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
329 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
330 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
331 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
332 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
333 |
+
|
334 |
+
x = self.norm(x)
|
335 |
+
x = self.reduction(x)
|
336 |
+
|
337 |
+
return x
|
338 |
+
|
339 |
+
def extra_repr(self) -> str:
|
340 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
341 |
+
|
342 |
+
def flops(self):
|
343 |
+
H, W = self.input_resolution
|
344 |
+
flops = H * W * self.dim
|
345 |
+
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
346 |
+
return flops
|
347 |
+
|
348 |
+
|
349 |
+
class BasicLayer(nn.Module):
|
350 |
+
""" A basic Swin Transformer layer for one stage.
|
351 |
+
|
352 |
+
Args:
|
353 |
+
dim (int): Number of input channels.
|
354 |
+
input_resolution (tuple[int]): Input resolution.
|
355 |
+
depth (int): Number of blocks.
|
356 |
+
num_heads (int): Number of attention heads.
|
357 |
+
window_size (int): Local window size.
|
358 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
359 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
360 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
361 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
362 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
363 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
364 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
365 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
366 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
367 |
+
"""
|
368 |
+
|
369 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
370 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
371 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
|
372 |
+
|
373 |
+
super().__init__()
|
374 |
+
self.dim = dim
|
375 |
+
self.input_resolution = input_resolution
|
376 |
+
self.depth = depth
|
377 |
+
self.use_checkpoint = use_checkpoint
|
378 |
+
|
379 |
+
# build blocks
|
380 |
+
self.blocks = nn.ModuleList([
|
381 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
382 |
+
num_heads=num_heads, window_size=window_size,
|
383 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
384 |
+
mlp_ratio=mlp_ratio,
|
385 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
386 |
+
drop=drop, attn_drop=attn_drop,
|
387 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
388 |
+
norm_layer=norm_layer)
|
389 |
+
for i in range(depth)])
|
390 |
+
|
391 |
+
# patch merging layer
|
392 |
+
if downsample is not None:
|
393 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
394 |
+
else:
|
395 |
+
self.downsample = None
|
396 |
+
|
397 |
+
def forward(self, x, x_size):
|
398 |
+
for blk in self.blocks:
|
399 |
+
if self.use_checkpoint:
|
400 |
+
x = checkpoint.checkpoint(blk, x, x_size)
|
401 |
+
else:
|
402 |
+
x = blk(x, x_size)
|
403 |
+
if self.downsample is not None:
|
404 |
+
x = self.downsample(x)
|
405 |
+
return x
|
406 |
+
|
407 |
+
def extra_repr(self) -> str:
|
408 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
409 |
+
|
410 |
+
def flops(self):
|
411 |
+
flops = 0
|
412 |
+
for blk in self.blocks:
|
413 |
+
flops += blk.flops()
|
414 |
+
if self.downsample is not None:
|
415 |
+
flops += self.downsample.flops()
|
416 |
+
return flops
|
417 |
+
|
418 |
+
|
419 |
+
class RSTB(nn.Module):
|
420 |
+
"""Residual Swin Transformer Block (RSTB).
|
421 |
+
|
422 |
+
Args:
|
423 |
+
dim (int): Number of input channels.
|
424 |
+
input_resolution (tuple[int]): Input resolution.
|
425 |
+
depth (int): Number of blocks.
|
426 |
+
num_heads (int): Number of attention heads.
|
427 |
+
window_size (int): Local window size.
|
428 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
429 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
430 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
431 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
432 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
433 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
434 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
435 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
436 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
437 |
+
img_size: Input image size.
|
438 |
+
patch_size: Patch size.
|
439 |
+
resi_connection: The convolutional block before residual connection.
|
440 |
+
"""
|
441 |
+
|
442 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
443 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
444 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
445 |
+
img_size=224, patch_size=4, resi_connection='1conv'):
|
446 |
+
super(RSTB, self).__init__()
|
447 |
+
|
448 |
+
self.dim = dim
|
449 |
+
self.input_resolution = input_resolution
|
450 |
+
|
451 |
+
self.residual_group = BasicLayer(dim=dim,
|
452 |
+
input_resolution=input_resolution,
|
453 |
+
depth=depth,
|
454 |
+
num_heads=num_heads,
|
455 |
+
window_size=window_size,
|
456 |
+
mlp_ratio=mlp_ratio,
|
457 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
458 |
+
drop=drop, attn_drop=attn_drop,
|
459 |
+
drop_path=drop_path,
|
460 |
+
norm_layer=norm_layer,
|
461 |
+
downsample=downsample,
|
462 |
+
use_checkpoint=use_checkpoint)
|
463 |
+
|
464 |
+
if resi_connection == '1conv':
|
465 |
+
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
466 |
+
elif resi_connection == '3conv':
|
467 |
+
# to save parameters and memory
|
468 |
+
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
469 |
+
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
|
470 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
471 |
+
nn.Conv2d(dim // 4, dim, 3, 1, 1))
|
472 |
+
|
473 |
+
self.patch_embed = PatchEmbed(
|
474 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
475 |
+
norm_layer=None)
|
476 |
+
|
477 |
+
self.patch_unembed = PatchUnEmbed(
|
478 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
479 |
+
norm_layer=None)
|
480 |
+
|
481 |
+
def forward(self, x, x_size):
|
482 |
+
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
|
483 |
+
|
484 |
+
def flops(self):
|
485 |
+
flops = 0
|
486 |
+
flops += self.residual_group.flops()
|
487 |
+
H, W = self.input_resolution
|
488 |
+
flops += H * W * self.dim * self.dim * 9
|
489 |
+
flops += self.patch_embed.flops()
|
490 |
+
flops += self.patch_unembed.flops()
|
491 |
+
|
492 |
+
return flops
|
493 |
+
|
494 |
+
|
495 |
+
class PatchEmbed(nn.Module):
|
496 |
+
r""" Image to Patch Embedding
|
497 |
+
|
498 |
+
Args:
|
499 |
+
img_size (int): Image size. Default: 224.
|
500 |
+
patch_size (int): Patch token size. Default: 4.
|
501 |
+
in_chans (int): Number of input image channels. Default: 3.
|
502 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
503 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
504 |
+
"""
|
505 |
+
|
506 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
507 |
+
super().__init__()
|
508 |
+
img_size = to_2tuple(img_size)
|
509 |
+
patch_size = to_2tuple(patch_size)
|
510 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
511 |
+
self.img_size = img_size
|
512 |
+
self.patch_size = patch_size
|
513 |
+
self.patches_resolution = patches_resolution
|
514 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
515 |
+
|
516 |
+
self.in_chans = in_chans
|
517 |
+
self.embed_dim = embed_dim
|
518 |
+
|
519 |
+
if norm_layer is not None:
|
520 |
+
self.norm = norm_layer(embed_dim)
|
521 |
+
else:
|
522 |
+
self.norm = None
|
523 |
+
|
524 |
+
def forward(self, x):
|
525 |
+
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
|
526 |
+
if self.norm is not None:
|
527 |
+
x = self.norm(x)
|
528 |
+
return x
|
529 |
+
|
530 |
+
def flops(self):
|
531 |
+
flops = 0
|
532 |
+
H, W = self.img_size
|
533 |
+
if self.norm is not None:
|
534 |
+
flops += H * W * self.embed_dim
|
535 |
+
return flops
|
536 |
+
|
537 |
+
|
538 |
+
class PatchUnEmbed(nn.Module):
|
539 |
+
r""" Image to Patch Unembedding
|
540 |
+
|
541 |
+
Args:
|
542 |
+
img_size (int): Image size. Default: 224.
|
543 |
+
patch_size (int): Patch token size. Default: 4.
|
544 |
+
in_chans (int): Number of input image channels. Default: 3.
|
545 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
546 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
547 |
+
"""
|
548 |
+
|
549 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
550 |
+
super().__init__()
|
551 |
+
img_size = to_2tuple(img_size)
|
552 |
+
patch_size = to_2tuple(patch_size)
|
553 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
554 |
+
self.img_size = img_size
|
555 |
+
self.patch_size = patch_size
|
556 |
+
self.patches_resolution = patches_resolution
|
557 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
558 |
+
|
559 |
+
self.in_chans = in_chans
|
560 |
+
self.embed_dim = embed_dim
|
561 |
+
|
562 |
+
def forward(self, x, x_size):
|
563 |
+
B, HW, C = x.shape
|
564 |
+
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
|
565 |
+
return x
|
566 |
+
|
567 |
+
def flops(self):
|
568 |
+
flops = 0
|
569 |
+
return flops
|
570 |
+
|
571 |
+
|
572 |
+
class Upsample(nn.Sequential):
|
573 |
+
"""Upsample module.
|
574 |
+
|
575 |
+
Args:
|
576 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
577 |
+
num_feat (int): Channel number of intermediate features.
|
578 |
+
"""
|
579 |
+
|
580 |
+
def __init__(self, scale, num_feat):
|
581 |
+
m = []
|
582 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
583 |
+
for _ in range(int(math.log(scale, 2))):
|
584 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
585 |
+
m.append(nn.PixelShuffle(2))
|
586 |
+
elif scale == 3:
|
587 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
588 |
+
m.append(nn.PixelShuffle(3))
|
589 |
+
else:
|
590 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
591 |
+
super(Upsample, self).__init__(*m)
|
592 |
+
|
593 |
+
|
594 |
+
class UpsampleOneStep(nn.Sequential):
|
595 |
+
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
|
596 |
+
Used in lightweight SR to save parameters.
|
597 |
+
|
598 |
+
Args:
|
599 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
600 |
+
num_feat (int): Channel number of intermediate features.
|
601 |
+
|
602 |
+
"""
|
603 |
+
|
604 |
+
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
|
605 |
+
self.num_feat = num_feat
|
606 |
+
self.input_resolution = input_resolution
|
607 |
+
m = []
|
608 |
+
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
|
609 |
+
m.append(nn.PixelShuffle(scale))
|
610 |
+
super(UpsampleOneStep, self).__init__(*m)
|
611 |
+
|
612 |
+
def flops(self):
|
613 |
+
H, W = self.input_resolution
|
614 |
+
flops = H * W * self.num_feat * 3 * 9
|
615 |
+
return flops
|
616 |
+
|
617 |
+
|
618 |
+
class SwinIR(nn.Module):
|
619 |
+
r""" SwinIR
|
620 |
+
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
|
621 |
+
|
622 |
+
Args:
|
623 |
+
img_size (int | tuple(int)): Input image size. Default 64
|
624 |
+
patch_size (int | tuple(int)): Patch size. Default: 1
|
625 |
+
in_chans (int): Number of input image channels. Default: 3
|
626 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
627 |
+
depths (tuple(int)): Depth of each Swin Transformer layer.
|
628 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
629 |
+
window_size (int): Window size. Default: 7
|
630 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
631 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
632 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
633 |
+
drop_rate (float): Dropout rate. Default: 0
|
634 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
635 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
636 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
637 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
638 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
639 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
640 |
+
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
|
641 |
+
img_range: Image range. 1. or 255.
|
642 |
+
upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
|
643 |
+
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
|
644 |
+
"""
|
645 |
+
|
646 |
+
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
647 |
+
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
648 |
+
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
649 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
650 |
+
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
651 |
+
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
652 |
+
**kwargs):
|
653 |
+
super(SwinIR, self).__init__()
|
654 |
+
num_in_ch = in_chans
|
655 |
+
num_out_ch = in_chans
|
656 |
+
num_feat = 64
|
657 |
+
self.img_range = img_range
|
658 |
+
if in_chans == 3:
|
659 |
+
rgb_mean = (0.4488, 0.4371, 0.4040)
|
660 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
661 |
+
else:
|
662 |
+
self.mean = torch.zeros(1, 1, 1, 1)
|
663 |
+
self.upscale = upscale
|
664 |
+
self.upsampler = upsampler
|
665 |
+
self.window_size = window_size
|
666 |
+
|
667 |
+
#####################################################################################################
|
668 |
+
################################### 1, shallow feature extraction ###################################
|
669 |
+
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
670 |
+
|
671 |
+
#####################################################################################################
|
672 |
+
################################### 2, deep feature extraction ######################################
|
673 |
+
self.num_layers = len(depths)
|
674 |
+
self.embed_dim = embed_dim
|
675 |
+
self.ape = ape
|
676 |
+
self.patch_norm = patch_norm
|
677 |
+
self.num_features = embed_dim
|
678 |
+
self.mlp_ratio = mlp_ratio
|
679 |
+
|
680 |
+
# split image into non-overlapping patches
|
681 |
+
self.patch_embed = PatchEmbed(
|
682 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
683 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
684 |
+
num_patches = self.patch_embed.num_patches
|
685 |
+
patches_resolution = self.patch_embed.patches_resolution
|
686 |
+
self.patches_resolution = patches_resolution
|
687 |
+
|
688 |
+
# merge non-overlapping patches into image
|
689 |
+
self.patch_unembed = PatchUnEmbed(
|
690 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
691 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
692 |
+
|
693 |
+
# absolute position embedding
|
694 |
+
if self.ape:
|
695 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
696 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
697 |
+
|
698 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
699 |
+
|
700 |
+
# stochastic depth
|
701 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
702 |
+
|
703 |
+
# build Residual Swin Transformer blocks (RSTB)
|
704 |
+
self.layers = nn.ModuleList()
|
705 |
+
for i_layer in range(self.num_layers):
|
706 |
+
layer = RSTB(dim=embed_dim,
|
707 |
+
input_resolution=(patches_resolution[0],
|
708 |
+
patches_resolution[1]),
|
709 |
+
depth=depths[i_layer],
|
710 |
+
num_heads=num_heads[i_layer],
|
711 |
+
window_size=window_size,
|
712 |
+
mlp_ratio=self.mlp_ratio,
|
713 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
714 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
715 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
716 |
+
norm_layer=norm_layer,
|
717 |
+
downsample=None,
|
718 |
+
use_checkpoint=use_checkpoint,
|
719 |
+
img_size=img_size,
|
720 |
+
patch_size=patch_size,
|
721 |
+
resi_connection=resi_connection
|
722 |
+
|
723 |
+
)
|
724 |
+
self.layers.append(layer)
|
725 |
+
self.norm = norm_layer(self.num_features)
|
726 |
+
|
727 |
+
# build the last conv layer in deep feature extraction
|
728 |
+
if resi_connection == '1conv':
|
729 |
+
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
730 |
+
elif resi_connection == '3conv':
|
731 |
+
# to save parameters and memory
|
732 |
+
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
|
733 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
734 |
+
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
|
735 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
736 |
+
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
|
737 |
+
|
738 |
+
#####################################################################################################
|
739 |
+
################################ 3, high quality image reconstruction ################################
|
740 |
+
if self.upsampler == 'pixelshuffle':
|
741 |
+
# for classical SR
|
742 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
743 |
+
nn.LeakyReLU(inplace=True))
|
744 |
+
self.upsample = Upsample(upscale, num_feat)
|
745 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
746 |
+
elif self.upsampler == 'pixelshuffledirect':
|
747 |
+
# for lightweight SR (to save parameters)
|
748 |
+
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
749 |
+
(patches_resolution[0], patches_resolution[1]))
|
750 |
+
elif self.upsampler == 'nearest+conv':
|
751 |
+
# for real-world SR (less artifacts)
|
752 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
753 |
+
nn.LeakyReLU(inplace=True))
|
754 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
755 |
+
if self.upscale == 4:
|
756 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
757 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
758 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
759 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
760 |
+
else:
|
761 |
+
# for image denoising and JPEG compression artifact reduction
|
762 |
+
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
|
763 |
+
|
764 |
+
self.apply(self._init_weights)
|
765 |
+
|
766 |
+
def _init_weights(self, m):
|
767 |
+
if isinstance(m, nn.Linear):
|
768 |
+
trunc_normal_(m.weight, std=.02)
|
769 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
770 |
+
nn.init.constant_(m.bias, 0)
|
771 |
+
elif isinstance(m, nn.LayerNorm):
|
772 |
+
nn.init.constant_(m.bias, 0)
|
773 |
+
nn.init.constant_(m.weight, 1.0)
|
774 |
+
|
775 |
+
@torch.jit.ignore
|
776 |
+
def no_weight_decay(self):
|
777 |
+
return {'absolute_pos_embed'}
|
778 |
+
|
779 |
+
@torch.jit.ignore
|
780 |
+
def no_weight_decay_keywords(self):
|
781 |
+
return {'relative_position_bias_table'}
|
782 |
+
|
783 |
+
def check_image_size(self, x):
|
784 |
+
_, _, h, w = x.size()
|
785 |
+
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
786 |
+
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
787 |
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
788 |
+
return x
|
789 |
+
|
790 |
+
def forward_features(self, x):
|
791 |
+
x_size = (x.shape[2], x.shape[3])
|
792 |
+
x = self.patch_embed(x)
|
793 |
+
if self.ape:
|
794 |
+
x = x + self.absolute_pos_embed
|
795 |
+
x = self.pos_drop(x)
|
796 |
+
|
797 |
+
for layer in self.layers:
|
798 |
+
x = layer(x, x_size)
|
799 |
+
|
800 |
+
x = self.norm(x) # B L C
|
801 |
+
x = self.patch_unembed(x, x_size)
|
802 |
+
|
803 |
+
return x
|
804 |
+
|
805 |
+
def forward(self, x):
|
806 |
+
H, W = x.shape[2:]
|
807 |
+
x = self.check_image_size(x)
|
808 |
+
|
809 |
+
self.mean = self.mean.type_as(x)
|
810 |
+
x = (x - self.mean) * self.img_range
|
811 |
+
|
812 |
+
if self.upsampler == 'pixelshuffle':
|
813 |
+
# for classical SR
|
814 |
+
x = self.conv_first(x)
|
815 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
816 |
+
x = self.conv_before_upsample(x)
|
817 |
+
x = self.conv_last(self.upsample(x))
|
818 |
+
elif self.upsampler == 'pixelshuffledirect':
|
819 |
+
# for lightweight SR
|
820 |
+
x = self.conv_first(x)
|
821 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
822 |
+
x = self.upsample(x)
|
823 |
+
elif self.upsampler == 'nearest+conv':
|
824 |
+
# for real-world SR
|
825 |
+
x = self.conv_first(x)
|
826 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
827 |
+
x = self.conv_before_upsample(x)
|
828 |
+
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
829 |
+
if self.upscale == 4:
|
830 |
+
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
831 |
+
x = self.conv_last(self.lrelu(self.conv_hr(x)))
|
832 |
+
else:
|
833 |
+
# for image denoising and JPEG compression artifact reduction
|
834 |
+
x_first = self.conv_first(x)
|
835 |
+
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
836 |
+
x = x + self.conv_last(res)
|
837 |
+
|
838 |
+
x = x / self.img_range + self.mean
|
839 |
+
|
840 |
+
return x[:, :, :H*self.upscale, :W*self.upscale]
|
841 |
+
|
842 |
+
def flops(self):
|
843 |
+
flops = 0
|
844 |
+
H, W = self.patches_resolution
|
845 |
+
flops += H * W * 3 * self.embed_dim * 9
|
846 |
+
flops += self.patch_embed.flops()
|
847 |
+
for i, layer in enumerate(self.layers):
|
848 |
+
flops += layer.flops()
|
849 |
+
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
850 |
+
flops += self.upsample.flops()
|
851 |
+
return flops
|
852 |
+
|
853 |
+
|
854 |
+
if __name__ == '__main__':
|
855 |
+
upscale = 4
|
856 |
+
window_size = 8
|
857 |
+
height = (1024 // upscale // window_size + 1) * window_size
|
858 |
+
width = (720 // upscale // window_size + 1) * window_size
|
859 |
+
model = SwinIR(upscale=2, img_size=(height, width),
|
860 |
+
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
|
861 |
+
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
|
862 |
+
print(model)
|
863 |
+
print(height, width, model.flops() / 1e9)
|
864 |
+
|
865 |
+
x = torch.randn((1, 3, height, width))
|
866 |
+
x = model(x)
|
867 |
+
print(x.shape)
|
modules/txt2img.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import modules.scripts
|
2 |
+
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
3 |
+
from modules.shared import opts, cmd_opts
|
4 |
+
import modules.shared as shared
|
5 |
+
import modules.processing as processing
|
6 |
+
from modules.ui import plaintext_to_html
|
7 |
+
|
8 |
+
|
9 |
+
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, *args):
|
10 |
+
p = StableDiffusionProcessingTxt2Img(
|
11 |
+
sd_model=shared.sd_model,
|
12 |
+
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
13 |
+
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
|
14 |
+
prompt=prompt,
|
15 |
+
styles=[prompt_style, prompt_style2],
|
16 |
+
negative_prompt=negative_prompt,
|
17 |
+
seed=seed,
|
18 |
+
subseed=subseed,
|
19 |
+
subseed_strength=subseed_strength,
|
20 |
+
seed_resize_from_h=seed_resize_from_h,
|
21 |
+
seed_resize_from_w=seed_resize_from_w,
|
22 |
+
seed_enable_extras=seed_enable_extras,
|
23 |
+
sampler_index=sampler_index,
|
24 |
+
batch_size=batch_size,
|
25 |
+
n_iter=n_iter,
|
26 |
+
steps=steps,
|
27 |
+
cfg_scale=cfg_scale,
|
28 |
+
width=width,
|
29 |
+
height=height,
|
30 |
+
restore_faces=restore_faces,
|
31 |
+
tiling=tiling,
|
32 |
+
enable_hr=enable_hr,
|
33 |
+
scale_latent=scale_latent if enable_hr else None,
|
34 |
+
denoising_strength=denoising_strength if enable_hr else None,
|
35 |
+
)
|
36 |
+
|
37 |
+
if cmd_opts.enable_console_prompts:
|
38 |
+
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
39 |
+
|
40 |
+
processed = modules.scripts.scripts_txt2img.run(p, *args)
|
41 |
+
|
42 |
+
if processed is None:
|
43 |
+
processed = process_images(p)
|
44 |
+
|
45 |
+
shared.total_tqdm.clear()
|
46 |
+
|
47 |
+
generation_info_js = processed.js()
|
48 |
+
if opts.samples_log_stdout:
|
49 |
+
print(generation_info_js)
|
50 |
+
|
51 |
+
if opts.do_not_show_images:
|
52 |
+
processed.images = []
|
53 |
+
|
54 |
+
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
55 |
+
|
modules/ui.py
ADDED
@@ -0,0 +1,1414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import html
|
3 |
+
import io
|
4 |
+
import json
|
5 |
+
import math
|
6 |
+
import mimetypes
|
7 |
+
import os
|
8 |
+
import random
|
9 |
+
import sys
|
10 |
+
import time
|
11 |
+
import traceback
|
12 |
+
import platform
|
13 |
+
import subprocess as sp
|
14 |
+
from functools import reduce
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
from PIL import Image, PngImagePlugin
|
19 |
+
import piexif
|
20 |
+
|
21 |
+
import gradio as gr
|
22 |
+
import gradio.utils
|
23 |
+
import gradio.routes
|
24 |
+
|
25 |
+
from modules import sd_hijack
|
26 |
+
from modules.paths import script_path
|
27 |
+
from modules.shared import opts, cmd_opts
|
28 |
+
import modules.shared as shared
|
29 |
+
from modules.sd_samplers import samplers, samplers_for_img2img
|
30 |
+
from modules.sd_hijack import model_hijack
|
31 |
+
import modules.ldsr_model
|
32 |
+
import modules.scripts
|
33 |
+
import modules.gfpgan_model
|
34 |
+
import modules.codeformer_model
|
35 |
+
import modules.styles
|
36 |
+
import modules.generation_parameters_copypaste
|
37 |
+
from modules import prompt_parser
|
38 |
+
from modules.images import save_image
|
39 |
+
import modules.textual_inversion.ui
|
40 |
+
|
41 |
+
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
|
42 |
+
mimetypes.init()
|
43 |
+
mimetypes.add_type('application/javascript', '.js')
|
44 |
+
|
45 |
+
|
46 |
+
if not cmd_opts.share and not cmd_opts.listen:
|
47 |
+
# fix gradio phoning home
|
48 |
+
gradio.utils.version_check = lambda: None
|
49 |
+
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
|
50 |
+
|
51 |
+
|
52 |
+
def gr_show(visible=True):
|
53 |
+
return {"visible": visible, "__type__": "update"}
|
54 |
+
|
55 |
+
|
56 |
+
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
57 |
+
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
|
58 |
+
|
59 |
+
css_hide_progressbar = """
|
60 |
+
.wrap .m-12 svg { display:none!important; }
|
61 |
+
.wrap .m-12::before { content:"Loading..." }
|
62 |
+
.progress-bar { display:none!important; }
|
63 |
+
.meta-text { display:none!important; }
|
64 |
+
"""
|
65 |
+
|
66 |
+
# Using constants for these since the variation selector isn't visible.
|
67 |
+
# Important that they exactly match script.js for tooltip to work.
|
68 |
+
random_symbol = '\U0001f3b2\ufe0f' # 🎲️
|
69 |
+
reuse_symbol = '\u267b\ufe0f' # ♻️
|
70 |
+
art_symbol = '\U0001f3a8' # 🎨
|
71 |
+
paste_symbol = '\u2199\ufe0f' # ↙
|
72 |
+
folder_symbol = '\U0001f4c2' # 📂
|
73 |
+
|
74 |
+
def plaintext_to_html(text):
|
75 |
+
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
|
76 |
+
return text
|
77 |
+
|
78 |
+
|
79 |
+
def image_from_url_text(filedata):
|
80 |
+
if type(filedata) == list:
|
81 |
+
if len(filedata) == 0:
|
82 |
+
return None
|
83 |
+
|
84 |
+
filedata = filedata[0]
|
85 |
+
|
86 |
+
if filedata.startswith("data:image/png;base64,"):
|
87 |
+
filedata = filedata[len("data:image/png;base64,"):]
|
88 |
+
|
89 |
+
filedata = base64.decodebytes(filedata.encode('utf-8'))
|
90 |
+
image = Image.open(io.BytesIO(filedata))
|
91 |
+
return image
|
92 |
+
|
93 |
+
|
94 |
+
def send_gradio_gallery_to_image(x):
|
95 |
+
if len(x) == 0:
|
96 |
+
return None
|
97 |
+
|
98 |
+
return image_from_url_text(x[0])
|
99 |
+
|
100 |
+
|
101 |
+
def save_files(js_data, images, index):
|
102 |
+
import csv
|
103 |
+
filenames = []
|
104 |
+
|
105 |
+
#quick dictionary to class object conversion. Its neccesary due apply_filename_pattern requiring it
|
106 |
+
class MyObject:
|
107 |
+
def __init__(self, d=None):
|
108 |
+
if d is not None:
|
109 |
+
for key, value in d.items():
|
110 |
+
setattr(self, key, value)
|
111 |
+
|
112 |
+
data = json.loads(js_data)
|
113 |
+
|
114 |
+
p = MyObject(data)
|
115 |
+
path = opts.outdir_save
|
116 |
+
save_to_dirs = opts.use_save_to_dirs_for_ui
|
117 |
+
extension: str = opts.samples_format
|
118 |
+
start_index = 0
|
119 |
+
|
120 |
+
if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
|
121 |
+
|
122 |
+
images = [images[index]]
|
123 |
+
start_index = index
|
124 |
+
|
125 |
+
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
|
126 |
+
at_start = file.tell() == 0
|
127 |
+
writer = csv.writer(file)
|
128 |
+
if at_start:
|
129 |
+
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
|
130 |
+
|
131 |
+
for image_index, filedata in enumerate(images, start_index):
|
132 |
+
if filedata.startswith("data:image/png;base64,"):
|
133 |
+
filedata = filedata[len("data:image/png;base64,"):]
|
134 |
+
|
135 |
+
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
|
136 |
+
|
137 |
+
is_grid = image_index < p.index_of_first_image
|
138 |
+
i = 0 if is_grid else (image_index - p.index_of_first_image)
|
139 |
+
|
140 |
+
fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
|
141 |
+
|
142 |
+
filename = os.path.relpath(fullfn, path)
|
143 |
+
filenames.append(filename)
|
144 |
+
|
145 |
+
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
|
146 |
+
|
147 |
+
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
148 |
+
|
149 |
+
|
150 |
+
def wrap_gradio_call(func, extra_outputs=None):
|
151 |
+
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
152 |
+
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
|
153 |
+
if run_memmon:
|
154 |
+
shared.mem_mon.monitor()
|
155 |
+
t = time.perf_counter()
|
156 |
+
|
157 |
+
try:
|
158 |
+
res = list(func(*args, **kwargs))
|
159 |
+
except Exception as e:
|
160 |
+
print("Error completing request", file=sys.stderr)
|
161 |
+
print("Arguments:", args, kwargs, file=sys.stderr)
|
162 |
+
print(traceback.format_exc(), file=sys.stderr)
|
163 |
+
|
164 |
+
shared.state.job = ""
|
165 |
+
shared.state.job_count = 0
|
166 |
+
|
167 |
+
if extra_outputs_array is None:
|
168 |
+
extra_outputs_array = [None, '']
|
169 |
+
|
170 |
+
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
171 |
+
|
172 |
+
elapsed = time.perf_counter() - t
|
173 |
+
elapsed_m = int(elapsed // 60)
|
174 |
+
elapsed_s = elapsed % 60
|
175 |
+
elapsed_text = f"{elapsed_s:.2f}s"
|
176 |
+
if (elapsed_m > 0):
|
177 |
+
elapsed_text = f"{elapsed_m}m "+elapsed_text
|
178 |
+
|
179 |
+
if run_memmon:
|
180 |
+
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
|
181 |
+
active_peak = mem_stats['active_peak']
|
182 |
+
reserved_peak = mem_stats['reserved_peak']
|
183 |
+
sys_peak = mem_stats['system_peak']
|
184 |
+
sys_total = mem_stats['total']
|
185 |
+
sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
|
186 |
+
|
187 |
+
vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
|
188 |
+
else:
|
189 |
+
vram_html = ''
|
190 |
+
|
191 |
+
# last item is always HTML
|
192 |
+
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
|
193 |
+
|
194 |
+
shared.state.interrupted = False
|
195 |
+
shared.state.job_count = 0
|
196 |
+
|
197 |
+
return tuple(res)
|
198 |
+
|
199 |
+
return f
|
200 |
+
|
201 |
+
|
202 |
+
def check_progress_call(id_part):
|
203 |
+
if shared.state.job_count == 0:
|
204 |
+
return "", gr_show(False), gr_show(False), gr_show(False)
|
205 |
+
|
206 |
+
progress = 0
|
207 |
+
|
208 |
+
if shared.state.job_count > 0:
|
209 |
+
progress += shared.state.job_no / shared.state.job_count
|
210 |
+
if shared.state.sampling_steps > 0:
|
211 |
+
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
|
212 |
+
|
213 |
+
progress = min(progress, 1)
|
214 |
+
|
215 |
+
progressbar = ""
|
216 |
+
if opts.show_progressbar:
|
217 |
+
progressbar = f"""<div class='progressDiv'><div class='progress' style="width:{progress * 100}%">{str(int(progress*100))+"%" if progress > 0.01 else ""}</div></div>"""
|
218 |
+
|
219 |
+
image = gr_show(False)
|
220 |
+
preview_visibility = gr_show(False)
|
221 |
+
|
222 |
+
if opts.show_progress_every_n_steps > 0:
|
223 |
+
if shared.parallel_processing_allowed:
|
224 |
+
|
225 |
+
if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
|
226 |
+
shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
|
227 |
+
shared.state.current_image_sampling_step = shared.state.sampling_step
|
228 |
+
|
229 |
+
image = shared.state.current_image
|
230 |
+
|
231 |
+
if image is None:
|
232 |
+
image = gr.update(value=None)
|
233 |
+
else:
|
234 |
+
preview_visibility = gr_show(True)
|
235 |
+
|
236 |
+
if shared.state.textinfo is not None:
|
237 |
+
textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
|
238 |
+
else:
|
239 |
+
textinfo_result = gr_show(False)
|
240 |
+
|
241 |
+
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
|
242 |
+
|
243 |
+
|
244 |
+
def check_progress_call_initial(id_part):
|
245 |
+
shared.state.job_count = -1
|
246 |
+
shared.state.current_latent = None
|
247 |
+
shared.state.current_image = None
|
248 |
+
shared.state.textinfo = None
|
249 |
+
|
250 |
+
return check_progress_call(id_part)
|
251 |
+
|
252 |
+
|
253 |
+
def roll_artist(prompt):
|
254 |
+
allowed_cats = set([x for x in shared.artist_db.categories() if len(opts.random_artist_categories)==0 or x in opts.random_artist_categories])
|
255 |
+
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
|
256 |
+
|
257 |
+
return prompt + ", " + artist.name if prompt != '' else artist.name
|
258 |
+
|
259 |
+
|
260 |
+
def visit(x, func, path=""):
|
261 |
+
if hasattr(x, 'children'):
|
262 |
+
for c in x.children:
|
263 |
+
visit(c, func, path)
|
264 |
+
elif x.label is not None:
|
265 |
+
func(path + "/" + str(x.label), x)
|
266 |
+
|
267 |
+
|
268 |
+
def add_style(name: str, prompt: str, negative_prompt: str):
|
269 |
+
if name is None:
|
270 |
+
return [gr_show(), gr_show()]
|
271 |
+
|
272 |
+
style = modules.styles.PromptStyle(name, prompt, negative_prompt)
|
273 |
+
shared.prompt_styles.styles[style.name] = style
|
274 |
+
# Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
|
275 |
+
# reserialize all styles every time we save them
|
276 |
+
shared.prompt_styles.save_styles(shared.styles_filename)
|
277 |
+
|
278 |
+
return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
|
279 |
+
|
280 |
+
|
281 |
+
def apply_styles(prompt, prompt_neg, style1_name, style2_name):
|
282 |
+
prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
|
283 |
+
prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
|
284 |
+
|
285 |
+
return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")]
|
286 |
+
|
287 |
+
|
288 |
+
def interrogate(image):
|
289 |
+
prompt = shared.interrogator.interrogate(image)
|
290 |
+
|
291 |
+
return gr_show(True) if prompt is None else prompt
|
292 |
+
|
293 |
+
|
294 |
+
def create_seed_inputs():
|
295 |
+
with gr.Row():
|
296 |
+
with gr.Box():
|
297 |
+
with gr.Row(elem_id='seed_row'):
|
298 |
+
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1)
|
299 |
+
seed.style(container=False)
|
300 |
+
random_seed = gr.Button(random_symbol, elem_id='random_seed')
|
301 |
+
reuse_seed = gr.Button(reuse_symbol, elem_id='reuse_seed')
|
302 |
+
|
303 |
+
with gr.Box(elem_id='subseed_show_box'):
|
304 |
+
seed_checkbox = gr.Checkbox(label='Extra', elem_id='subseed_show', value=False)
|
305 |
+
|
306 |
+
# Components to show/hide based on the 'Extra' checkbox
|
307 |
+
seed_extras = []
|
308 |
+
|
309 |
+
with gr.Row(visible=False) as seed_extra_row_1:
|
310 |
+
seed_extras.append(seed_extra_row_1)
|
311 |
+
with gr.Box():
|
312 |
+
with gr.Row(elem_id='subseed_row'):
|
313 |
+
subseed = gr.Number(label='Variation seed', value=-1)
|
314 |
+
subseed.style(container=False)
|
315 |
+
random_subseed = gr.Button(random_symbol, elem_id='random_subseed')
|
316 |
+
reuse_subseed = gr.Button(reuse_symbol, elem_id='reuse_subseed')
|
317 |
+
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01)
|
318 |
+
|
319 |
+
with gr.Row(visible=False) as seed_extra_row_2:
|
320 |
+
seed_extras.append(seed_extra_row_2)
|
321 |
+
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0)
|
322 |
+
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0)
|
323 |
+
|
324 |
+
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
|
325 |
+
random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
|
326 |
+
|
327 |
+
def change_visibility(show):
|
328 |
+
return {comp: gr_show(show) for comp in seed_extras}
|
329 |
+
|
330 |
+
seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
|
331 |
+
|
332 |
+
return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
|
333 |
+
|
334 |
+
|
335 |
+
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
|
336 |
+
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
|
337 |
+
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
|
338 |
+
was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
|
339 |
+
def copy_seed(gen_info_string: str, index):
|
340 |
+
res = -1
|
341 |
+
|
342 |
+
try:
|
343 |
+
gen_info = json.loads(gen_info_string)
|
344 |
+
index -= gen_info.get('index_of_first_image', 0)
|
345 |
+
|
346 |
+
if is_subseed and gen_info.get('subseed_strength', 0) > 0:
|
347 |
+
all_subseeds = gen_info.get('all_subseeds', [-1])
|
348 |
+
res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
|
349 |
+
else:
|
350 |
+
all_seeds = gen_info.get('all_seeds', [-1])
|
351 |
+
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
|
352 |
+
|
353 |
+
except json.decoder.JSONDecodeError as e:
|
354 |
+
if gen_info_string != '':
|
355 |
+
print("Error parsing JSON generation info:", file=sys.stderr)
|
356 |
+
print(gen_info_string, file=sys.stderr)
|
357 |
+
|
358 |
+
return [res, gr_show(False)]
|
359 |
+
|
360 |
+
reuse_seed.click(
|
361 |
+
fn=copy_seed,
|
362 |
+
_js="(x, y) => [x, selected_gallery_index()]",
|
363 |
+
show_progress=False,
|
364 |
+
inputs=[generation_info, dummy_component],
|
365 |
+
outputs=[seed, dummy_component]
|
366 |
+
)
|
367 |
+
|
368 |
+
|
369 |
+
def update_token_counter(text, steps):
|
370 |
+
try:
|
371 |
+
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
|
372 |
+
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
|
373 |
+
|
374 |
+
except Exception:
|
375 |
+
# a parsing error can happen here during typing, and we don't want to bother the user with
|
376 |
+
# messages related to it in console
|
377 |
+
prompt_schedules = [[[steps, text]]]
|
378 |
+
|
379 |
+
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
|
380 |
+
prompts = [prompt_text for step, prompt_text in flat_prompts]
|
381 |
+
tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
|
382 |
+
style_class = ' class="red"' if (token_count > max_length) else ""
|
383 |
+
return f"<span {style_class}>{token_count}/{max_length}</span>"
|
384 |
+
|
385 |
+
|
386 |
+
def create_toprow(is_img2img):
|
387 |
+
id_part = "img2img" if is_img2img else "txt2img"
|
388 |
+
|
389 |
+
with gr.Row(elem_id="toprow"):
|
390 |
+
with gr.Column(scale=4):
|
391 |
+
with gr.Row():
|
392 |
+
with gr.Column(scale=80):
|
393 |
+
with gr.Row():
|
394 |
+
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
|
395 |
+
|
396 |
+
with gr.Column(scale=1, elem_id="roll_col"):
|
397 |
+
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
398 |
+
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
399 |
+
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
400 |
+
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
401 |
+
|
402 |
+
with gr.Column(scale=10, elem_id="style_pos_col"):
|
403 |
+
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
404 |
+
|
405 |
+
with gr.Row():
|
406 |
+
with gr.Column(scale=8):
|
407 |
+
negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
|
408 |
+
|
409 |
+
with gr.Column(scale=1, elem_id="style_neg_col"):
|
410 |
+
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
411 |
+
|
412 |
+
with gr.Column(scale=1):
|
413 |
+
with gr.Row():
|
414 |
+
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
|
415 |
+
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
416 |
+
|
417 |
+
interrupt.click(
|
418 |
+
fn=lambda: shared.state.interrupt(),
|
419 |
+
inputs=[],
|
420 |
+
outputs=[],
|
421 |
+
)
|
422 |
+
|
423 |
+
with gr.Row():
|
424 |
+
if is_img2img:
|
425 |
+
interrogate = gr.Button('Interrogate', elem_id="interrogate")
|
426 |
+
else:
|
427 |
+
interrogate = None
|
428 |
+
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
|
429 |
+
save_style = gr.Button('Create style', elem_id="style_create")
|
430 |
+
|
431 |
+
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button
|
432 |
+
|
433 |
+
|
434 |
+
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
|
435 |
+
if textinfo is None:
|
436 |
+
textinfo = gr.HTML(visible=False)
|
437 |
+
|
438 |
+
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
|
439 |
+
check_progress.click(
|
440 |
+
fn=lambda: check_progress_call(id_part),
|
441 |
+
show_progress=False,
|
442 |
+
inputs=[],
|
443 |
+
outputs=[progressbar, preview, preview, textinfo],
|
444 |
+
)
|
445 |
+
|
446 |
+
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
|
447 |
+
check_progress_initial.click(
|
448 |
+
fn=lambda: check_progress_call_initial(id_part),
|
449 |
+
show_progress=False,
|
450 |
+
inputs=[],
|
451 |
+
outputs=[progressbar, preview, preview, textinfo],
|
452 |
+
)
|
453 |
+
|
454 |
+
|
455 |
+
def create_ui(wrap_gradio_gpu_call):
|
456 |
+
import modules.img2img
|
457 |
+
import modules.txt2img
|
458 |
+
|
459 |
+
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
460 |
+
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
|
461 |
+
dummy_component = gr.Label(visible=False)
|
462 |
+
|
463 |
+
with gr.Row(elem_id='txt2img_progress_row'):
|
464 |
+
with gr.Column(scale=1):
|
465 |
+
pass
|
466 |
+
|
467 |
+
with gr.Column(scale=1):
|
468 |
+
progressbar = gr.HTML(elem_id="txt2img_progressbar")
|
469 |
+
txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
|
470 |
+
setup_progressbar(progressbar, txt2img_preview, 'txt2img')
|
471 |
+
|
472 |
+
with gr.Row().style(equal_height=False):
|
473 |
+
with gr.Column(variant='panel'):
|
474 |
+
steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
|
475 |
+
sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
|
476 |
+
|
477 |
+
with gr.Group():
|
478 |
+
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
479 |
+
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
480 |
+
|
481 |
+
with gr.Row():
|
482 |
+
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
|
483 |
+
tiling = gr.Checkbox(label='Tiling', value=False)
|
484 |
+
enable_hr = gr.Checkbox(label='Highres. fix', value=False)
|
485 |
+
|
486 |
+
with gr.Row(visible=False) as hr_options:
|
487 |
+
scale_latent = gr.Checkbox(label='Scale latent', value=False)
|
488 |
+
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
|
489 |
+
|
490 |
+
with gr.Row():
|
491 |
+
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
|
492 |
+
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
|
493 |
+
|
494 |
+
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
|
495 |
+
|
496 |
+
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
|
497 |
+
|
498 |
+
with gr.Group():
|
499 |
+
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
|
500 |
+
|
501 |
+
with gr.Column(variant='panel'):
|
502 |
+
|
503 |
+
with gr.Group():
|
504 |
+
txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
|
505 |
+
txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4)
|
506 |
+
|
507 |
+
with gr.Group():
|
508 |
+
with gr.Row():
|
509 |
+
save = gr.Button('Save')
|
510 |
+
send_to_img2img = gr.Button('Send to img2img')
|
511 |
+
send_to_inpaint = gr.Button('Send to inpaint')
|
512 |
+
send_to_extras = gr.Button('Send to extras')
|
513 |
+
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
|
514 |
+
open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
|
515 |
+
|
516 |
+
with gr.Group():
|
517 |
+
html_info = gr.HTML()
|
518 |
+
generation_info = gr.Textbox(visible=False)
|
519 |
+
|
520 |
+
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
521 |
+
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
522 |
+
|
523 |
+
txt2img_args = dict(
|
524 |
+
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
|
525 |
+
_js="submit",
|
526 |
+
inputs=[
|
527 |
+
txt2img_prompt,
|
528 |
+
txt2img_negative_prompt,
|
529 |
+
txt2img_prompt_style,
|
530 |
+
txt2img_prompt_style2,
|
531 |
+
steps,
|
532 |
+
sampler_index,
|
533 |
+
restore_faces,
|
534 |
+
tiling,
|
535 |
+
batch_count,
|
536 |
+
batch_size,
|
537 |
+
cfg_scale,
|
538 |
+
seed,
|
539 |
+
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
|
540 |
+
height,
|
541 |
+
width,
|
542 |
+
enable_hr,
|
543 |
+
scale_latent,
|
544 |
+
denoising_strength,
|
545 |
+
] + custom_inputs,
|
546 |
+
outputs=[
|
547 |
+
txt2img_gallery,
|
548 |
+
generation_info,
|
549 |
+
html_info
|
550 |
+
],
|
551 |
+
show_progress=False,
|
552 |
+
)
|
553 |
+
|
554 |
+
txt2img_prompt.submit(**txt2img_args)
|
555 |
+
submit.click(**txt2img_args)
|
556 |
+
|
557 |
+
enable_hr.change(
|
558 |
+
fn=lambda x: gr_show(x),
|
559 |
+
inputs=[enable_hr],
|
560 |
+
outputs=[hr_options],
|
561 |
+
)
|
562 |
+
|
563 |
+
save.click(
|
564 |
+
fn=wrap_gradio_call(save_files),
|
565 |
+
_js="(x, y, z) => [x, y, selected_gallery_index()]",
|
566 |
+
inputs=[
|
567 |
+
generation_info,
|
568 |
+
txt2img_gallery,
|
569 |
+
html_info,
|
570 |
+
],
|
571 |
+
outputs=[
|
572 |
+
html_info,
|
573 |
+
html_info,
|
574 |
+
html_info,
|
575 |
+
]
|
576 |
+
)
|
577 |
+
|
578 |
+
roll.click(
|
579 |
+
fn=roll_artist,
|
580 |
+
_js="update_txt2img_tokens",
|
581 |
+
inputs=[
|
582 |
+
txt2img_prompt,
|
583 |
+
],
|
584 |
+
outputs=[
|
585 |
+
txt2img_prompt,
|
586 |
+
]
|
587 |
+
)
|
588 |
+
|
589 |
+
txt2img_paste_fields = [
|
590 |
+
(txt2img_prompt, "Prompt"),
|
591 |
+
(txt2img_negative_prompt, "Negative prompt"),
|
592 |
+
(steps, "Steps"),
|
593 |
+
(sampler_index, "Sampler"),
|
594 |
+
(restore_faces, "Face restoration"),
|
595 |
+
(cfg_scale, "CFG scale"),
|
596 |
+
(seed, "Seed"),
|
597 |
+
(width, "Size-1"),
|
598 |
+
(height, "Size-2"),
|
599 |
+
(batch_size, "Batch size"),
|
600 |
+
(subseed, "Variation seed"),
|
601 |
+
(subseed_strength, "Variation seed strength"),
|
602 |
+
(seed_resize_from_w, "Seed resize from-1"),
|
603 |
+
(seed_resize_from_h, "Seed resize from-2"),
|
604 |
+
(denoising_strength, "Denoising strength"),
|
605 |
+
(enable_hr, lambda d: "Denoising strength" in d),
|
606 |
+
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
|
607 |
+
]
|
608 |
+
modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
|
609 |
+
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
|
610 |
+
|
611 |
+
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
612 |
+
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
|
613 |
+
|
614 |
+
with gr.Row(elem_id='img2img_progress_row'):
|
615 |
+
with gr.Column(scale=1):
|
616 |
+
pass
|
617 |
+
|
618 |
+
with gr.Column(scale=1):
|
619 |
+
progressbar = gr.HTML(elem_id="img2img_progressbar")
|
620 |
+
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
|
621 |
+
setup_progressbar(progressbar, img2img_preview, 'img2img')
|
622 |
+
|
623 |
+
with gr.Row().style(equal_height=False):
|
624 |
+
with gr.Column(variant='panel'):
|
625 |
+
|
626 |
+
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
|
627 |
+
with gr.TabItem('img2img', id='img2img'):
|
628 |
+
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool)
|
629 |
+
|
630 |
+
with gr.TabItem('Inpaint', id='inpaint'):
|
631 |
+
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
|
632 |
+
|
633 |
+
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
|
634 |
+
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
|
635 |
+
|
636 |
+
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
|
637 |
+
|
638 |
+
with gr.Row():
|
639 |
+
mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
|
640 |
+
inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")
|
641 |
+
|
642 |
+
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index")
|
643 |
+
|
644 |
+
with gr.Row():
|
645 |
+
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
|
646 |
+
inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32)
|
647 |
+
|
648 |
+
with gr.TabItem('Batch img2img', id='batch'):
|
649 |
+
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
|
650 |
+
gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
|
651 |
+
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
|
652 |
+
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
|
653 |
+
|
654 |
+
with gr.Row():
|
655 |
+
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
|
656 |
+
|
657 |
+
steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
|
658 |
+
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
|
659 |
+
|
660 |
+
with gr.Group():
|
661 |
+
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
662 |
+
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
663 |
+
|
664 |
+
with gr.Row():
|
665 |
+
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
|
666 |
+
tiling = gr.Checkbox(label='Tiling', value=False)
|
667 |
+
|
668 |
+
with gr.Row():
|
669 |
+
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
|
670 |
+
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
|
671 |
+
|
672 |
+
with gr.Group():
|
673 |
+
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
|
674 |
+
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75)
|
675 |
+
|
676 |
+
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
|
677 |
+
|
678 |
+
with gr.Group():
|
679 |
+
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
|
680 |
+
|
681 |
+
with gr.Column(variant='panel'):
|
682 |
+
|
683 |
+
with gr.Group():
|
684 |
+
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
|
685 |
+
img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4)
|
686 |
+
|
687 |
+
with gr.Group():
|
688 |
+
with gr.Row():
|
689 |
+
save = gr.Button('Save')
|
690 |
+
img2img_send_to_img2img = gr.Button('Send to img2img')
|
691 |
+
img2img_send_to_inpaint = gr.Button('Send to inpaint')
|
692 |
+
img2img_send_to_extras = gr.Button('Send to extras')
|
693 |
+
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
|
694 |
+
open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
|
695 |
+
|
696 |
+
with gr.Group():
|
697 |
+
html_info = gr.HTML()
|
698 |
+
generation_info = gr.Textbox(visible=False)
|
699 |
+
|
700 |
+
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
701 |
+
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
702 |
+
|
703 |
+
mask_mode.change(
|
704 |
+
lambda mode, img: {
|
705 |
+
init_img_with_mask: gr_show(mode == 0),
|
706 |
+
init_img_inpaint: gr_show(mode == 1),
|
707 |
+
init_mask_inpaint: gr_show(mode == 1),
|
708 |
+
},
|
709 |
+
inputs=[mask_mode, init_img_with_mask],
|
710 |
+
outputs=[
|
711 |
+
init_img_with_mask,
|
712 |
+
init_img_inpaint,
|
713 |
+
init_mask_inpaint,
|
714 |
+
],
|
715 |
+
)
|
716 |
+
|
717 |
+
img2img_args = dict(
|
718 |
+
fn=wrap_gradio_gpu_call(modules.img2img.img2img),
|
719 |
+
_js="submit_img2img",
|
720 |
+
inputs=[
|
721 |
+
dummy_component,
|
722 |
+
img2img_prompt,
|
723 |
+
img2img_negative_prompt,
|
724 |
+
img2img_prompt_style,
|
725 |
+
img2img_prompt_style2,
|
726 |
+
init_img,
|
727 |
+
init_img_with_mask,
|
728 |
+
init_img_inpaint,
|
729 |
+
init_mask_inpaint,
|
730 |
+
mask_mode,
|
731 |
+
steps,
|
732 |
+
sampler_index,
|
733 |
+
mask_blur,
|
734 |
+
inpainting_fill,
|
735 |
+
restore_faces,
|
736 |
+
tiling,
|
737 |
+
batch_count,
|
738 |
+
batch_size,
|
739 |
+
cfg_scale,
|
740 |
+
denoising_strength,
|
741 |
+
seed,
|
742 |
+
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
|
743 |
+
height,
|
744 |
+
width,
|
745 |
+
resize_mode,
|
746 |
+
inpaint_full_res,
|
747 |
+
inpaint_full_res_padding,
|
748 |
+
inpainting_mask_invert,
|
749 |
+
img2img_batch_input_dir,
|
750 |
+
img2img_batch_output_dir,
|
751 |
+
] + custom_inputs,
|
752 |
+
outputs=[
|
753 |
+
img2img_gallery,
|
754 |
+
generation_info,
|
755 |
+
html_info
|
756 |
+
],
|
757 |
+
show_progress=False,
|
758 |
+
)
|
759 |
+
|
760 |
+
img2img_prompt.submit(**img2img_args)
|
761 |
+
submit.click(**img2img_args)
|
762 |
+
|
763 |
+
img2img_interrogate.click(
|
764 |
+
fn=interrogate,
|
765 |
+
inputs=[init_img],
|
766 |
+
outputs=[img2img_prompt],
|
767 |
+
)
|
768 |
+
|
769 |
+
save.click(
|
770 |
+
fn=wrap_gradio_call(save_files),
|
771 |
+
_js="(x, y, z) => [x, y, selected_gallery_index()]",
|
772 |
+
inputs=[
|
773 |
+
generation_info,
|
774 |
+
img2img_gallery,
|
775 |
+
html_info
|
776 |
+
],
|
777 |
+
outputs=[
|
778 |
+
html_info,
|
779 |
+
html_info,
|
780 |
+
html_info,
|
781 |
+
]
|
782 |
+
)
|
783 |
+
|
784 |
+
roll.click(
|
785 |
+
fn=roll_artist,
|
786 |
+
_js="update_img2img_tokens",
|
787 |
+
inputs=[
|
788 |
+
img2img_prompt,
|
789 |
+
],
|
790 |
+
outputs=[
|
791 |
+
img2img_prompt,
|
792 |
+
]
|
793 |
+
)
|
794 |
+
|
795 |
+
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
|
796 |
+
style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
|
797 |
+
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
|
798 |
+
|
799 |
+
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
|
800 |
+
button.click(
|
801 |
+
fn=add_style,
|
802 |
+
_js="ask_for_style_name",
|
803 |
+
# Have to pass empty dummy component here, because the JavaScript and Python function have to accept
|
804 |
+
# the same number of parameters, but we only know the style-name after the JavaScript prompt
|
805 |
+
inputs=[dummy_component, prompt, negative_prompt],
|
806 |
+
outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
|
807 |
+
)
|
808 |
+
|
809 |
+
for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
|
810 |
+
button.click(
|
811 |
+
fn=apply_styles,
|
812 |
+
_js=js_func,
|
813 |
+
inputs=[prompt, negative_prompt, style1, style2],
|
814 |
+
outputs=[prompt, negative_prompt, style1, style2],
|
815 |
+
)
|
816 |
+
|
817 |
+
img2img_paste_fields = [
|
818 |
+
(img2img_prompt, "Prompt"),
|
819 |
+
(img2img_negative_prompt, "Negative prompt"),
|
820 |
+
(steps, "Steps"),
|
821 |
+
(sampler_index, "Sampler"),
|
822 |
+
(restore_faces, "Face restoration"),
|
823 |
+
(cfg_scale, "CFG scale"),
|
824 |
+
(seed, "Seed"),
|
825 |
+
(width, "Size-1"),
|
826 |
+
(height, "Size-2"),
|
827 |
+
(batch_size, "Batch size"),
|
828 |
+
(subseed, "Variation seed"),
|
829 |
+
(subseed_strength, "Variation seed strength"),
|
830 |
+
(seed_resize_from_w, "Seed resize from-1"),
|
831 |
+
(seed_resize_from_h, "Seed resize from-2"),
|
832 |
+
(denoising_strength, "Denoising strength"),
|
833 |
+
]
|
834 |
+
modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)
|
835 |
+
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
836 |
+
|
837 |
+
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
838 |
+
with gr.Row().style(equal_height=False):
|
839 |
+
with gr.Column(variant='panel'):
|
840 |
+
with gr.Tabs(elem_id="mode_extras"):
|
841 |
+
with gr.TabItem('Single Image'):
|
842 |
+
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil")
|
843 |
+
|
844 |
+
with gr.TabItem('Batch Process'):
|
845 |
+
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
|
846 |
+
|
847 |
+
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
|
848 |
+
|
849 |
+
with gr.Group():
|
850 |
+
extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
851 |
+
|
852 |
+
with gr.Group():
|
853 |
+
extras_upscaler_2 = gr.Radio(label='Upscaler 2', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
854 |
+
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
|
855 |
+
|
856 |
+
with gr.Group():
|
857 |
+
gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan)
|
858 |
+
|
859 |
+
with gr.Group():
|
860 |
+
codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer)
|
861 |
+
codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer)
|
862 |
+
|
863 |
+
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
|
864 |
+
|
865 |
+
with gr.Column(variant='panel'):
|
866 |
+
result_images = gr.Gallery(label="Result", show_label=False)
|
867 |
+
html_info_x = gr.HTML()
|
868 |
+
html_info = gr.HTML()
|
869 |
+
extras_send_to_img2img = gr.Button('Send to img2img')
|
870 |
+
extras_send_to_inpaint = gr.Button('Send to inpaint')
|
871 |
+
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
|
872 |
+
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
|
873 |
+
|
874 |
+
submit.click(
|
875 |
+
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
|
876 |
+
_js="get_extras_tab_index",
|
877 |
+
inputs=[
|
878 |
+
dummy_component,
|
879 |
+
extras_image,
|
880 |
+
image_batch,
|
881 |
+
gfpgan_visibility,
|
882 |
+
codeformer_visibility,
|
883 |
+
codeformer_weight,
|
884 |
+
upscaling_resize,
|
885 |
+
extras_upscaler_1,
|
886 |
+
extras_upscaler_2,
|
887 |
+
extras_upscaler_2_visibility,
|
888 |
+
],
|
889 |
+
outputs=[
|
890 |
+
result_images,
|
891 |
+
html_info_x,
|
892 |
+
html_info,
|
893 |
+
]
|
894 |
+
)
|
895 |
+
|
896 |
+
extras_send_to_img2img.click(
|
897 |
+
fn=lambda x: image_from_url_text(x),
|
898 |
+
_js="extract_image_from_gallery_img2img",
|
899 |
+
inputs=[result_images],
|
900 |
+
outputs=[init_img],
|
901 |
+
)
|
902 |
+
|
903 |
+
extras_send_to_inpaint.click(
|
904 |
+
fn=lambda x: image_from_url_text(x),
|
905 |
+
_js="extract_image_from_gallery_img2img",
|
906 |
+
inputs=[result_images],
|
907 |
+
outputs=[init_img_with_mask],
|
908 |
+
)
|
909 |
+
|
910 |
+
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
|
911 |
+
with gr.Row().style(equal_height=False):
|
912 |
+
with gr.Column(variant='panel'):
|
913 |
+
image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
|
914 |
+
|
915 |
+
with gr.Column(variant='panel'):
|
916 |
+
html = gr.HTML()
|
917 |
+
generation_info = gr.Textbox(visible=False)
|
918 |
+
html2 = gr.HTML()
|
919 |
+
|
920 |
+
with gr.Row():
|
921 |
+
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
|
922 |
+
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
923 |
+
|
924 |
+
image.change(
|
925 |
+
fn=wrap_gradio_call(modules.extras.run_pnginfo),
|
926 |
+
inputs=[image],
|
927 |
+
outputs=[html, generation_info, html2],
|
928 |
+
)
|
929 |
+
|
930 |
+
with gr.Blocks() as modelmerger_interface:
|
931 |
+
with gr.Row().style(equal_height=False):
|
932 |
+
with gr.Column(variant='panel'):
|
933 |
+
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
|
934 |
+
|
935 |
+
with gr.Row():
|
936 |
+
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
|
937 |
+
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
|
938 |
+
custom_name = gr.Textbox(label="Custom Name (Optional)")
|
939 |
+
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
|
940 |
+
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
|
941 |
+
save_as_half = gr.Checkbox(value=False, label="Safe as float16")
|
942 |
+
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
|
943 |
+
|
944 |
+
with gr.Column(variant='panel'):
|
945 |
+
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
946 |
+
|
947 |
+
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
948 |
+
|
949 |
+
with gr.Blocks() as textual_inversion_interface:
|
950 |
+
with gr.Row().style(equal_height=False):
|
951 |
+
with gr.Column():
|
952 |
+
with gr.Group():
|
953 |
+
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
|
954 |
+
|
955 |
+
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
|
956 |
+
|
957 |
+
new_embedding_name = gr.Textbox(label="Name")
|
958 |
+
initialization_text = gr.Textbox(label="Initialization text", value="*")
|
959 |
+
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
|
960 |
+
|
961 |
+
with gr.Row():
|
962 |
+
with gr.Column(scale=3):
|
963 |
+
gr.HTML(value="")
|
964 |
+
|
965 |
+
with gr.Column():
|
966 |
+
create_embedding = gr.Button(value="Create", variant='primary')
|
967 |
+
|
968 |
+
with gr.Group():
|
969 |
+
gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
|
970 |
+
|
971 |
+
process_src = gr.Textbox(label='Source directory')
|
972 |
+
process_dst = gr.Textbox(label='Destination directory')
|
973 |
+
|
974 |
+
with gr.Row():
|
975 |
+
process_flip = gr.Checkbox(label='Flip')
|
976 |
+
process_split = gr.Checkbox(label='Split into two')
|
977 |
+
process_caption = gr.Checkbox(label='Add caption')
|
978 |
+
|
979 |
+
with gr.Row():
|
980 |
+
with gr.Column(scale=3):
|
981 |
+
gr.HTML(value="")
|
982 |
+
|
983 |
+
with gr.Column():
|
984 |
+
run_preprocess = gr.Button(value="Preprocess", variant='primary')
|
985 |
+
|
986 |
+
with gr.Group():
|
987 |
+
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
|
988 |
+
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
989 |
+
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
|
990 |
+
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
991 |
+
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
992 |
+
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
993 |
+
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
994 |
+
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
995 |
+
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
996 |
+
|
997 |
+
with gr.Row():
|
998 |
+
with gr.Column(scale=2):
|
999 |
+
gr.HTML(value="")
|
1000 |
+
|
1001 |
+
with gr.Column():
|
1002 |
+
with gr.Row():
|
1003 |
+
interrupt_training = gr.Button(value="Interrupt")
|
1004 |
+
train_embedding = gr.Button(value="Train", variant='primary')
|
1005 |
+
|
1006 |
+
with gr.Column():
|
1007 |
+
progressbar = gr.HTML(elem_id="ti_progressbar")
|
1008 |
+
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
1009 |
+
|
1010 |
+
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
|
1011 |
+
ti_preview = gr.Image(elem_id='ti_preview', visible=False)
|
1012 |
+
ti_progress = gr.HTML(elem_id="ti_progress", value="")
|
1013 |
+
ti_outcome = gr.HTML(elem_id="ti_error", value="")
|
1014 |
+
setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
|
1015 |
+
|
1016 |
+
create_embedding.click(
|
1017 |
+
fn=modules.textual_inversion.ui.create_embedding,
|
1018 |
+
inputs=[
|
1019 |
+
new_embedding_name,
|
1020 |
+
initialization_text,
|
1021 |
+
nvpt,
|
1022 |
+
],
|
1023 |
+
outputs=[
|
1024 |
+
train_embedding_name,
|
1025 |
+
ti_output,
|
1026 |
+
ti_outcome,
|
1027 |
+
]
|
1028 |
+
)
|
1029 |
+
|
1030 |
+
run_preprocess.click(
|
1031 |
+
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
|
1032 |
+
_js="start_training_textual_inversion",
|
1033 |
+
inputs=[
|
1034 |
+
process_src,
|
1035 |
+
process_dst,
|
1036 |
+
process_flip,
|
1037 |
+
process_split,
|
1038 |
+
process_caption,
|
1039 |
+
],
|
1040 |
+
outputs=[
|
1041 |
+
ti_output,
|
1042 |
+
ti_outcome,
|
1043 |
+
],
|
1044 |
+
)
|
1045 |
+
|
1046 |
+
train_embedding.click(
|
1047 |
+
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
|
1048 |
+
_js="start_training_textual_inversion",
|
1049 |
+
inputs=[
|
1050 |
+
train_embedding_name,
|
1051 |
+
learn_rate,
|
1052 |
+
dataset_directory,
|
1053 |
+
log_directory,
|
1054 |
+
steps,
|
1055 |
+
create_image_every,
|
1056 |
+
save_embedding_every,
|
1057 |
+
template_file,
|
1058 |
+
],
|
1059 |
+
outputs=[
|
1060 |
+
ti_output,
|
1061 |
+
ti_outcome,
|
1062 |
+
]
|
1063 |
+
)
|
1064 |
+
|
1065 |
+
interrupt_training.click(
|
1066 |
+
fn=lambda: shared.state.interrupt(),
|
1067 |
+
inputs=[],
|
1068 |
+
outputs=[],
|
1069 |
+
)
|
1070 |
+
|
1071 |
+
def create_setting_component(key):
|
1072 |
+
def fun():
|
1073 |
+
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
1074 |
+
|
1075 |
+
info = opts.data_labels[key]
|
1076 |
+
t = type(info.default)
|
1077 |
+
|
1078 |
+
args = info.component_args() if callable(info.component_args) else info.component_args
|
1079 |
+
|
1080 |
+
if info.component is not None:
|
1081 |
+
comp = info.component
|
1082 |
+
elif t == str:
|
1083 |
+
comp = gr.Textbox
|
1084 |
+
elif t == int:
|
1085 |
+
comp = gr.Number
|
1086 |
+
elif t == bool:
|
1087 |
+
comp = gr.Checkbox
|
1088 |
+
else:
|
1089 |
+
raise Exception(f'bad options item type: {str(t)} for key {key}')
|
1090 |
+
|
1091 |
+
return comp(label=info.label, value=fun, **(args or {}))
|
1092 |
+
|
1093 |
+
components = []
|
1094 |
+
component_dict = {}
|
1095 |
+
|
1096 |
+
def open_folder(f):
|
1097 |
+
if not shared.cmd_opts.hide_ui_dir_config:
|
1098 |
+
path = os.path.normpath(f)
|
1099 |
+
if platform.system() == "Windows":
|
1100 |
+
os.startfile(path)
|
1101 |
+
elif platform.system() == "Darwin":
|
1102 |
+
sp.Popen(["open", path])
|
1103 |
+
else:
|
1104 |
+
sp.Popen(["xdg-open", path])
|
1105 |
+
|
1106 |
+
def run_settings(*args):
|
1107 |
+
changed = 0
|
1108 |
+
|
1109 |
+
for key, value, comp in zip(opts.data_labels.keys(), args, components):
|
1110 |
+
if not opts.same_type(value, opts.data_labels[key].default):
|
1111 |
+
return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
|
1112 |
+
|
1113 |
+
for key, value, comp in zip(opts.data_labels.keys(), args, components):
|
1114 |
+
comp_args = opts.data_labels[key].component_args
|
1115 |
+
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
|
1116 |
+
continue
|
1117 |
+
|
1118 |
+
oldval = opts.data.get(key, None)
|
1119 |
+
opts.data[key] = value
|
1120 |
+
|
1121 |
+
if oldval != value:
|
1122 |
+
if opts.data_labels[key].onchange is not None:
|
1123 |
+
opts.data_labels[key].onchange()
|
1124 |
+
|
1125 |
+
changed += 1
|
1126 |
+
|
1127 |
+
opts.save(shared.config_filename)
|
1128 |
+
|
1129 |
+
return f'{changed} settings changed.', opts.dumpjson()
|
1130 |
+
|
1131 |
+
with gr.Blocks(analytics_enabled=False) as settings_interface:
|
1132 |
+
settings_submit = gr.Button(value="Apply settings", variant='primary')
|
1133 |
+
result = gr.HTML()
|
1134 |
+
|
1135 |
+
settings_cols = 3
|
1136 |
+
items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
|
1137 |
+
|
1138 |
+
cols_displayed = 0
|
1139 |
+
items_displayed = 0
|
1140 |
+
previous_section = None
|
1141 |
+
column = None
|
1142 |
+
with gr.Row(elem_id="settings").style(equal_height=False):
|
1143 |
+
for i, (k, item) in enumerate(opts.data_labels.items()):
|
1144 |
+
|
1145 |
+
if previous_section != item.section:
|
1146 |
+
if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None):
|
1147 |
+
if column is not None:
|
1148 |
+
column.__exit__()
|
1149 |
+
|
1150 |
+
column = gr.Column(variant='panel')
|
1151 |
+
column.__enter__()
|
1152 |
+
|
1153 |
+
items_displayed = 0
|
1154 |
+
cols_displayed += 1
|
1155 |
+
|
1156 |
+
previous_section = item.section
|
1157 |
+
|
1158 |
+
gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
|
1159 |
+
|
1160 |
+
component = create_setting_component(k)
|
1161 |
+
component_dict[k] = component
|
1162 |
+
components.append(component)
|
1163 |
+
items_displayed += 1
|
1164 |
+
|
1165 |
+
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
1166 |
+
request_notifications.click(
|
1167 |
+
fn=lambda: None,
|
1168 |
+
inputs=[],
|
1169 |
+
outputs=[],
|
1170 |
+
_js='function(){}'
|
1171 |
+
)
|
1172 |
+
|
1173 |
+
with gr.Row():
|
1174 |
+
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
|
1175 |
+
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
|
1176 |
+
|
1177 |
+
|
1178 |
+
def reload_scripts():
|
1179 |
+
modules.scripts.reload_script_body_only()
|
1180 |
+
|
1181 |
+
reload_script_bodies.click(
|
1182 |
+
fn=reload_scripts,
|
1183 |
+
inputs=[],
|
1184 |
+
outputs=[],
|
1185 |
+
_js='function(){}'
|
1186 |
+
)
|
1187 |
+
|
1188 |
+
def request_restart():
|
1189 |
+
shared.state.interrupt()
|
1190 |
+
settings_interface.gradio_ref.do_restart = True
|
1191 |
+
|
1192 |
+
restart_gradio.click(
|
1193 |
+
fn=request_restart,
|
1194 |
+
inputs=[],
|
1195 |
+
outputs=[],
|
1196 |
+
_js='function(){restart_reload()}'
|
1197 |
+
)
|
1198 |
+
|
1199 |
+
if column is not None:
|
1200 |
+
column.__exit__()
|
1201 |
+
|
1202 |
+
interfaces = [
|
1203 |
+
(txt2img_interface, "txt2img", "txt2img"),
|
1204 |
+
(img2img_interface, "img2img", "img2img"),
|
1205 |
+
(extras_interface, "Extras", "extras"),
|
1206 |
+
(pnginfo_interface, "PNG Info", "pnginfo"),
|
1207 |
+
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
1208 |
+
(textual_inversion_interface, "Textual inversion", "ti"),
|
1209 |
+
(settings_interface, "Settings", "settings"),
|
1210 |
+
]
|
1211 |
+
|
1212 |
+
with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
|
1213 |
+
css = file.read()
|
1214 |
+
|
1215 |
+
if os.path.exists(os.path.join(script_path, "user.css")):
|
1216 |
+
with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
|
1217 |
+
usercss = file.read()
|
1218 |
+
css += usercss
|
1219 |
+
|
1220 |
+
if not cmd_opts.no_progressbar_hiding:
|
1221 |
+
css += css_hide_progressbar
|
1222 |
+
|
1223 |
+
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
|
1224 |
+
|
1225 |
+
settings_interface.gradio_ref = demo
|
1226 |
+
|
1227 |
+
with gr.Tabs() as tabs:
|
1228 |
+
for interface, label, ifid in interfaces:
|
1229 |
+
with gr.TabItem(label, id=ifid):
|
1230 |
+
interface.render()
|
1231 |
+
|
1232 |
+
if os.path.exists(os.path.join(script_path, "notification.mp3")):
|
1233 |
+
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
|
1234 |
+
|
1235 |
+
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
|
1236 |
+
settings_submit.click(
|
1237 |
+
fn=run_settings,
|
1238 |
+
inputs=components,
|
1239 |
+
outputs=[result, text_settings],
|
1240 |
+
)
|
1241 |
+
|
1242 |
+
def modelmerger(*args):
|
1243 |
+
try:
|
1244 |
+
results = modules.extras.run_modelmerger(*args)
|
1245 |
+
except Exception as e:
|
1246 |
+
print("Error loading/saving model file:", file=sys.stderr)
|
1247 |
+
print(traceback.format_exc(), file=sys.stderr)
|
1248 |
+
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
1249 |
+
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
|
1250 |
+
return results
|
1251 |
+
|
1252 |
+
modelmerger_merge.click(
|
1253 |
+
fn=modelmerger,
|
1254 |
+
inputs=[
|
1255 |
+
primary_model_name,
|
1256 |
+
secondary_model_name,
|
1257 |
+
interp_method,
|
1258 |
+
interp_amount,
|
1259 |
+
save_as_half,
|
1260 |
+
custom_name,
|
1261 |
+
],
|
1262 |
+
outputs=[
|
1263 |
+
submit_result,
|
1264 |
+
primary_model_name,
|
1265 |
+
secondary_model_name,
|
1266 |
+
component_dict['sd_model_checkpoint'],
|
1267 |
+
]
|
1268 |
+
)
|
1269 |
+
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2']
|
1270 |
+
txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names]
|
1271 |
+
img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names]
|
1272 |
+
send_to_img2img.click(
|
1273 |
+
fn=lambda img, *args: (image_from_url_text(img),*args),
|
1274 |
+
_js="(gallery, ...args) => [extract_image_from_gallery_img2img(gallery), ...args]",
|
1275 |
+
inputs=[txt2img_gallery] + txt2img_fields,
|
1276 |
+
outputs=[init_img] + img2img_fields,
|
1277 |
+
)
|
1278 |
+
|
1279 |
+
send_to_inpaint.click(
|
1280 |
+
fn=lambda x, *args: (image_from_url_text(x), *args),
|
1281 |
+
_js="(gallery, ...args) => [extract_image_from_gallery_inpaint(gallery), ...args]",
|
1282 |
+
inputs=[txt2img_gallery] + txt2img_fields,
|
1283 |
+
outputs=[init_img_with_mask] + img2img_fields,
|
1284 |
+
)
|
1285 |
+
|
1286 |
+
img2img_send_to_img2img.click(
|
1287 |
+
fn=lambda x: image_from_url_text(x),
|
1288 |
+
_js="extract_image_from_gallery_img2img",
|
1289 |
+
inputs=[img2img_gallery],
|
1290 |
+
outputs=[init_img],
|
1291 |
+
)
|
1292 |
+
|
1293 |
+
img2img_send_to_inpaint.click(
|
1294 |
+
fn=lambda x: image_from_url_text(x),
|
1295 |
+
_js="extract_image_from_gallery_inpaint",
|
1296 |
+
inputs=[img2img_gallery],
|
1297 |
+
outputs=[init_img_with_mask],
|
1298 |
+
)
|
1299 |
+
|
1300 |
+
send_to_extras.click(
|
1301 |
+
fn=lambda x: image_from_url_text(x),
|
1302 |
+
_js="extract_image_from_gallery_extras",
|
1303 |
+
inputs=[txt2img_gallery],
|
1304 |
+
outputs=[extras_image],
|
1305 |
+
)
|
1306 |
+
|
1307 |
+
open_txt2img_folder.click(
|
1308 |
+
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_txt2img_samples),
|
1309 |
+
inputs=[],
|
1310 |
+
outputs=[],
|
1311 |
+
)
|
1312 |
+
|
1313 |
+
open_img2img_folder.click(
|
1314 |
+
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_img2img_samples),
|
1315 |
+
inputs=[],
|
1316 |
+
outputs=[],
|
1317 |
+
)
|
1318 |
+
|
1319 |
+
open_extras_folder.click(
|
1320 |
+
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_extras_samples),
|
1321 |
+
inputs=[],
|
1322 |
+
outputs=[],
|
1323 |
+
)
|
1324 |
+
|
1325 |
+
img2img_send_to_extras.click(
|
1326 |
+
fn=lambda x: image_from_url_text(x),
|
1327 |
+
_js="extract_image_from_gallery_extras",
|
1328 |
+
inputs=[img2img_gallery],
|
1329 |
+
outputs=[extras_image],
|
1330 |
+
)
|
1331 |
+
|
1332 |
+
modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields, generation_info, 'switch_to_txt2img')
|
1333 |
+
modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields, generation_info, 'switch_to_img2img_img2img')
|
1334 |
+
|
1335 |
+
ui_config_file = cmd_opts.ui_config_file
|
1336 |
+
ui_settings = {}
|
1337 |
+
settings_count = len(ui_settings)
|
1338 |
+
error_loading = False
|
1339 |
+
|
1340 |
+
try:
|
1341 |
+
if os.path.exists(ui_config_file):
|
1342 |
+
with open(ui_config_file, "r", encoding="utf8") as file:
|
1343 |
+
ui_settings = json.load(file)
|
1344 |
+
except Exception:
|
1345 |
+
error_loading = True
|
1346 |
+
print("Error loading settings:", file=sys.stderr)
|
1347 |
+
print(traceback.format_exc(), file=sys.stderr)
|
1348 |
+
|
1349 |
+
def loadsave(path, x):
|
1350 |
+
def apply_field(obj, field, condition=None):
|
1351 |
+
key = path + "/" + field
|
1352 |
+
|
1353 |
+
if getattr(obj,'custom_script_source',None) is not None:
|
1354 |
+
key = 'customscript/' + obj.custom_script_source + '/' + key
|
1355 |
+
|
1356 |
+
if getattr(obj, 'do_not_save_to_config', False):
|
1357 |
+
return
|
1358 |
+
|
1359 |
+
saved_value = ui_settings.get(key, None)
|
1360 |
+
if saved_value is None:
|
1361 |
+
ui_settings[key] = getattr(obj, field)
|
1362 |
+
elif condition is None or condition(saved_value):
|
1363 |
+
setattr(obj, field, saved_value)
|
1364 |
+
|
1365 |
+
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible:
|
1366 |
+
apply_field(x, 'visible')
|
1367 |
+
|
1368 |
+
if type(x) == gr.Slider:
|
1369 |
+
apply_field(x, 'value')
|
1370 |
+
apply_field(x, 'minimum')
|
1371 |
+
apply_field(x, 'maximum')
|
1372 |
+
apply_field(x, 'step')
|
1373 |
+
|
1374 |
+
if type(x) == gr.Radio:
|
1375 |
+
apply_field(x, 'value', lambda val: val in x.choices)
|
1376 |
+
|
1377 |
+
if type(x) == gr.Checkbox:
|
1378 |
+
apply_field(x, 'value')
|
1379 |
+
|
1380 |
+
if type(x) == gr.Textbox:
|
1381 |
+
apply_field(x, 'value')
|
1382 |
+
|
1383 |
+
if type(x) == gr.Number:
|
1384 |
+
apply_field(x, 'value')
|
1385 |
+
|
1386 |
+
visit(txt2img_interface, loadsave, "txt2img")
|
1387 |
+
visit(img2img_interface, loadsave, "img2img")
|
1388 |
+
visit(extras_interface, loadsave, "extras")
|
1389 |
+
|
1390 |
+
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
|
1391 |
+
with open(ui_config_file, "w", encoding="utf8") as file:
|
1392 |
+
json.dump(ui_settings, file, indent=4)
|
1393 |
+
|
1394 |
+
return demo
|
1395 |
+
|
1396 |
+
|
1397 |
+
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
|
1398 |
+
javascript = f'<script>{jsfile.read()}</script>'
|
1399 |
+
|
1400 |
+
jsdir = os.path.join(script_path, "javascript")
|
1401 |
+
for filename in sorted(os.listdir(jsdir)):
|
1402 |
+
with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
|
1403 |
+
javascript += f"\n<script>{jsfile.read()}</script>"
|
1404 |
+
|
1405 |
+
|
1406 |
+
if 'gradio_routes_templates_response' not in globals():
|
1407 |
+
def template_response(*args, **kwargs):
|
1408 |
+
res = gradio_routes_templates_response(*args, **kwargs)
|
1409 |
+
res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8"))
|
1410 |
+
res.init_headers()
|
1411 |
+
return res
|
1412 |
+
|
1413 |
+
gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
|
1414 |
+
gradio.routes.templates.TemplateResponse = template_response
|
modules/upscaler.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from abc import abstractmethod
|
3 |
+
|
4 |
+
import PIL
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
import modules.shared
|
10 |
+
from modules import modelloader, shared
|
11 |
+
|
12 |
+
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
13 |
+
from modules.paths import models_path
|
14 |
+
|
15 |
+
|
16 |
+
class Upscaler:
|
17 |
+
name = None
|
18 |
+
model_path = None
|
19 |
+
model_name = None
|
20 |
+
model_url = None
|
21 |
+
enable = True
|
22 |
+
filter = None
|
23 |
+
model = None
|
24 |
+
user_path = None
|
25 |
+
scalers: []
|
26 |
+
tile = True
|
27 |
+
|
28 |
+
def __init__(self, create_dirs=False):
|
29 |
+
self.mod_pad_h = None
|
30 |
+
self.tile_size = modules.shared.opts.ESRGAN_tile
|
31 |
+
self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
|
32 |
+
self.device = modules.shared.device
|
33 |
+
self.img = None
|
34 |
+
self.output = None
|
35 |
+
self.scale = 1
|
36 |
+
self.half = not modules.shared.cmd_opts.no_half
|
37 |
+
self.pre_pad = 0
|
38 |
+
self.mod_scale = None
|
39 |
+
if self.name is not None and create_dirs:
|
40 |
+
self.model_path = os.path.join(models_path, self.name)
|
41 |
+
if not os.path.exists(self.model_path):
|
42 |
+
os.makedirs(self.model_path)
|
43 |
+
|
44 |
+
try:
|
45 |
+
import cv2
|
46 |
+
self.can_tile = True
|
47 |
+
except:
|
48 |
+
pass
|
49 |
+
|
50 |
+
@abstractmethod
|
51 |
+
def do_upscale(self, img: PIL.Image, selected_model: str):
|
52 |
+
return img
|
53 |
+
|
54 |
+
def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
|
55 |
+
self.scale = scale
|
56 |
+
dest_w = img.width * scale
|
57 |
+
dest_h = img.height * scale
|
58 |
+
for i in range(3):
|
59 |
+
if img.width >= dest_w and img.height >= dest_h:
|
60 |
+
break
|
61 |
+
img = self.do_upscale(img, selected_model)
|
62 |
+
if img.width != dest_w or img.height != dest_h:
|
63 |
+
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
|
64 |
+
|
65 |
+
return img
|
66 |
+
|
67 |
+
@abstractmethod
|
68 |
+
def load_model(self, path: str):
|
69 |
+
pass
|
70 |
+
|
71 |
+
def find_models(self, ext_filter=None) -> list:
|
72 |
+
return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
|
73 |
+
|
74 |
+
def update_status(self, prompt):
|
75 |
+
print(f"\nextras: {prompt}", file=shared.progress_print_out)
|
76 |
+
|
77 |
+
|
78 |
+
class UpscalerData:
|
79 |
+
name = None
|
80 |
+
data_path = None
|
81 |
+
scale: int = 4
|
82 |
+
scaler: Upscaler = None
|
83 |
+
model: None
|
84 |
+
|
85 |
+
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
|
86 |
+
self.name = name
|
87 |
+
self.data_path = path
|
88 |
+
self.scaler = upscaler
|
89 |
+
self.scale = scale
|
90 |
+
self.model = model
|
91 |
+
|
92 |
+
|
93 |
+
class UpscalerNone(Upscaler):
|
94 |
+
name = "None"
|
95 |
+
scalers = []
|
96 |
+
|
97 |
+
def load_model(self, path):
|
98 |
+
pass
|
99 |
+
|
100 |
+
def do_upscale(self, img, selected_model=None):
|
101 |
+
return img
|
102 |
+
|
103 |
+
def __init__(self, dirname=None):
|
104 |
+
super().__init__(False)
|
105 |
+
self.scalers = [UpscalerData("None", None, self)]
|
106 |
+
|
107 |
+
|
108 |
+
class UpscalerLanczos(Upscaler):
|
109 |
+
scalers = []
|
110 |
+
|
111 |
+
def do_upscale(self, img, selected_model=None):
|
112 |
+
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
|
113 |
+
|
114 |
+
def load_model(self, _):
|
115 |
+
pass
|
116 |
+
|
117 |
+
def __init__(self, dirname=None):
|
118 |
+
super().__init__(False)
|
119 |
+
self.name = "Lanczos"
|
120 |
+
self.scalers = [UpscalerData("Lanczos", None, self)]
|
121 |
+
|