cymic committed on
Commit e9ac57f
1 Parent(s): 40794ad

Upload 41 files

modules/artists.py ADDED
@@ -0,0 +1,25 @@
+import os.path
+import csv
+from collections import namedtuple
+
+Artist = namedtuple("Artist", ['name', 'weight', 'category'])
+
+
+class ArtistsDatabase:
+    def __init__(self, filename):
+        self.cats = set()
+        self.artists = []
+
+        if not os.path.exists(filename):
+            return
+
+        with open(filename, "r", newline='', encoding="utf8") as file:
+            reader = csv.DictReader(file)
+
+            for row in reader:
+                artist = Artist(row["artist"], float(row["score"]), row["category"])
+                self.artists.append(artist)
+                self.cats.add(artist.category)
+
+    def categories(self):
+        return sorted(self.cats)
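The loader above expects a CSV with `artist`, `score` and `category` columns. A minimal usage sketch (illustrative only, not part of the commit; the file path and sample rows are hypothetical):

```python
from modules.artists import ArtistsDatabase

# artists.csv (hypothetical) is expected to look like:
#   artist,score,category
#   Alphonse Mucha,0.9,art nouveau
#   Ivan Shishkin,0.8,landscape
db = ArtistsDatabase("artists.csv")

print(db.categories())       # sorted list of unique categories
for a in db.artists:
    print(a.name, a.weight)  # the "score" column is parsed into the float `weight` field
```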
modules/bsrgan_model.py ADDED
@@ -0,0 +1,78 @@
+import os.path
+import sys
+import traceback
+
+import PIL.Image
+import numpy as np
+import torch
+from basicsr.utils.download_util import load_file_from_url
+
+import modules.upscaler
+from modules import devices, modelloader
+from modules.bsrgan_model_arch import RRDBNet
+from modules.paths import models_path
+
+
+class UpscalerBSRGAN(modules.upscaler.Upscaler):
+    def __init__(self, dirname):
+        self.name = "BSRGAN"
+        self.model_path = os.path.join(models_path, self.name)
+        self.model_name = "BSRGAN 4x"
+        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
+        self.user_path = dirname
+        super().__init__()
+        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
+        scalers = []
+        if len(model_paths) == 0:
+            scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
+            scalers.append(scaler_data)
+        for file in model_paths:
+            if "http" in file:
+                name = self.model_name
+            else:
+                name = modelloader.friendly_name(file)
+            try:
+                scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
+                scalers.append(scaler_data)
+            except Exception:
+                print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+        self.scalers = scalers
+
+    def do_upscale(self, img: PIL.Image, selected_file):
+        torch.cuda.empty_cache()
+        model = self.load_model(selected_file)
+        if model is None:
+            return img
+        model.to(devices.device_bsrgan)
+        torch.cuda.empty_cache()
+        img = np.array(img)
+        img = img[:, :, ::-1]
+        img = np.moveaxis(img, 2, 0) / 255
+        img = torch.from_numpy(img).float()
+        img = img.unsqueeze(0).to(devices.device_bsrgan)
+        with torch.no_grad():
+            output = model(img)
+        output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+        output = 255. * np.moveaxis(output, 0, 2)
+        output = output.astype(np.uint8)
+        output = output[:, :, ::-1]
+        torch.cuda.empty_cache()
+        return PIL.Image.fromarray(output, 'RGB')
+
+    def load_model(self, path: str):
+        if "http" in path:
+            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
+                                          progress=True)
+        else:
+            filename = path
+        if filename is None or not os.path.exists(filename):
+            print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
+            return None
+        model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4)  # define network
+        model.load_state_dict(torch.load(filename), strict=True)
+        model.eval()
+        for k, v in model.named_parameters():
+            v.requires_grad = False
+        return model
+
modules/bsrgan_model_arch.py ADDED
@@ -0,0 +1,102 @@
+import functools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+
+
+def initialize_weights(net_l, scale=1):
+    if not isinstance(net_l, list):
+        net_l = [net_l]
+    for net in net_l:
+        for m in net.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+                m.weight.data *= scale  # for residual block
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                init.constant_(m.weight, 1)
+                init.constant_(m.bias.data, 0.0)
+
+
+def make_layer(block, n_layers):
+    layers = []
+    for _ in range(n_layers):
+        layers.append(block())
+    return nn.Sequential(*layers)
+
+
+class ResidualDenseBlock_5C(nn.Module):
+    def __init__(self, nf=64, gc=32, bias=True):
+        super(ResidualDenseBlock_5C, self).__init__()
+        # gc: growth channel, i.e. intermediate channels
+        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
+        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
+        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
+        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
+        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+        # initialization
+        initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+    def forward(self, x):
+        x1 = self.lrelu(self.conv1(x))
+        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+    '''Residual in Residual Dense Block'''
+
+    def __init__(self, nf, gc=32):
+        super(RRDB, self).__init__()
+        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
+        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
+        self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+
+    def forward(self, x):
+        out = self.RDB1(x)
+        out = self.RDB2(out)
+        out = self.RDB3(out)
+        return out * 0.2 + x
+
+
+class RRDBNet(nn.Module):
+    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
+        super(RRDBNet, self).__init__()
+        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+        self.sf = sf
+
+        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
+        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
+        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        #### upsampling
+        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        if self.sf == 4:
+            self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        fea = self.conv_first(x)
+        trunk = self.trunk_conv(self.RRDB_trunk(fea))
+        fea = fea + trunk
+
+        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
+        if self.sf == 4:
+            fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
+        out = self.conv_last(self.lrelu(self.HRconv(fea)))
+
+        return out
modules/codeformer_model.py ADDED
@@ -0,0 +1,140 @@
+import os
+import sys
+import traceback
+
+import cv2
+import torch
+
+import modules.face_restoration
+import modules.shared
+from modules import shared, devices, modelloader
+from modules.paths import script_path, models_path
+
+# The CodeFormer authors chose to bundle a modified copy of the basicsr library with their project, which makes
+# it impossible to use it alongside other libraries that also depend on basicsr, such as GFPGAN.
+# I am therefore including some files from CodeFormer here to work around this issue.
+model_dir = "Codeformer"
+model_path = os.path.join(models_path, model_dir)
+model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+
+have_codeformer = False
+codeformer = None
+
+
+def setup_model(dirname):
+    global model_path
+    if not os.path.exists(model_path):
+        os.makedirs(model_path)
+
+    path = modules.paths.paths.get("CodeFormer", None)
+    if path is None:
+        return
+
+    try:
+        from torchvision.transforms.functional import normalize
+        from modules.codeformer.codeformer_arch import CodeFormer
+        from basicsr.utils.download_util import load_file_from_url
+        from basicsr.utils import imwrite, img2tensor, tensor2img
+        from facelib.utils.face_restoration_helper import FaceRestoreHelper
+        from modules.shared import cmd_opts
+
+        net_class = CodeFormer
+
+        class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
+            def name(self):
+                return "CodeFormer"
+
+            def __init__(self, dirname):
+                self.net = None
+                self.face_helper = None
+                self.cmd_dir = dirname
+
+            def create_models(self):
+
+                if self.net is not None and self.face_helper is not None:
+                    self.net.to(devices.device_codeformer)
+                    return self.net, self.face_helper
+                model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
+                if len(model_paths) != 0:
+                    ckpt_path = model_paths[0]
+                else:
+                    print("Unable to load codeformer model.")
+                    return None, None
+                net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
+                checkpoint = torch.load(ckpt_path)['params_ema']
+                net.load_state_dict(checkpoint)
+                net.eval()
+
+                face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
+
+                self.net = net
+                self.face_helper = face_helper
+
+                return net, face_helper
+
+            def send_model_to(self, device):
+                self.net.to(device)
+                self.face_helper.face_det.to(device)
+                self.face_helper.face_parse.to(device)
+
+            def restore(self, np_image, w=None):
+                np_image = np_image[:, :, ::-1]
+
+                original_resolution = np_image.shape[0:2]
+
+                self.create_models()
+                if self.net is None or self.face_helper is None:
+                    return np_image
+
+                self.send_model_to(devices.device_codeformer)
+
+                self.face_helper.clean_all()
+                self.face_helper.read_image(np_image)
+                self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
+                self.face_helper.align_warp_face()
+
+                for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
+                    cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+                    normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+                    cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
+
+                    try:
+                        with torch.no_grad():
+                            output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
+                            restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+                        del output
+                        torch.cuda.empty_cache()
+                    except Exception as error:
+                        print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
+                        restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+
+                    restored_face = restored_face.astype('uint8')
+                    self.face_helper.add_restored_face(restored_face)
+
+                self.face_helper.get_inverse_affine(None)
+
+                restored_img = self.face_helper.paste_faces_to_input_image()
+                restored_img = restored_img[:, :, ::-1]
+
+                if original_resolution != restored_img.shape[0:2]:
+                    restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
+
+                self.face_helper.clean_all()
+
+                if shared.opts.face_restoration_unload:
+                    self.send_model_to(devices.cpu)
+
+                return restored_img
+
+        global have_codeformer
+        have_codeformer = True
+
+        global codeformer
+        codeformer = FaceRestorerCodeFormer(dirname)
+        shared.face_restorers.append(codeformer)
+
+    except Exception:
+        print("Error setting up CodeFormer:", file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
+
+    # sys.path = stored_sys_path
modules/devices.py ADDED
@@ -0,0 +1,68 @@
+import contextlib
+
+import torch
+
+from modules import errors
+
+# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
+has_mps = getattr(torch, 'has_mps', False)
+
+cpu = torch.device("cpu")
+
+
+def get_optimal_device():
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+
+    if has_mps:
+        return torch.device("mps")
+
+    return cpu
+
+
+def torch_gc():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+
+
+def enable_tf32():
+    if torch.cuda.is_available():
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+
+
+errors.run(enable_tf32, "Enabling TF32")
+
+device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
+dtype = torch.float16
+
+def randn(seed, shape):
+    # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
+    if device.type == 'mps':
+        generator = torch.Generator(device=cpu)
+        generator.manual_seed(seed)
+        noise = torch.randn(shape, generator=generator, device=cpu).to(device)
+        return noise
+
+    torch.manual_seed(seed)
+    return torch.randn(shape, device=device)
+
+
+def randn_without_seed(shape):
+    # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
+    if device.type == 'mps':
+        generator = torch.Generator(device=cpu)
+        noise = torch.randn(shape, generator=generator, device=cpu).to(device)
+        return noise
+
+    return torch.randn(shape, device=device)
+
+
+def autocast():
+    from modules import shared
+
+    if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+        return contextlib.nullcontext()
+
+    return torch.autocast("cuda")
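A short sketch (not part of the commit) of how the helpers above are meant to be used: `randn` is deterministic for a given seed, generating on the CPU when the MPS backend is active, and `autocast` falls back to a null context at full precision (`shared.cmd_opts.precision == "full"`).

```python
import torch
from modules import devices

shape = (1, 4, 64, 64)                 # hypothetical latent shape
a = devices.randn(42, shape)
b = devices.randn(42, shape)
assert torch.equal(a.cpu(), b.cpu())   # same seed -> identical noise

with devices.autocast():               # nullcontext at full precision, torch.autocast("cuda") otherwise
    pass                               # the model forward pass would go here
```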
modules/errors.py ADDED
@@ -0,0 +1,10 @@
+import sys
+import traceback
+
+
+def run(code, task):
+    try:
+        code()
+    except Exception as e:
+        print(f"{task}: {type(e).__name__}", file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
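`run` lets optional setup steps fail loudly without aborting startup; devices.py above uses it to enable TF32. A minimal sketch (not part of the commit):

```python
from modules import errors

def flaky():
    raise RuntimeError("boom")

errors.run(flaky, "Doing the flaky thing")   # prints "Doing the flaky thing: RuntimeError" plus a traceback to stderr
print("still running")                       # execution continues past the failure
```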
modules/esrgam_model_arch.py ADDED
@@ -0,0 +1,80 @@
+# this file is taken from https://github.com/xinntao/ESRGAN
+
+import functools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def make_layer(block, n_layers):
+    layers = []
+    for _ in range(n_layers):
+        layers.append(block())
+    return nn.Sequential(*layers)
+
+
+class ResidualDenseBlock_5C(nn.Module):
+    def __init__(self, nf=64, gc=32, bias=True):
+        super(ResidualDenseBlock_5C, self).__init__()
+        # gc: growth channel, i.e. intermediate channels
+        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
+        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
+        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
+        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
+        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+        # initialization
+        # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+    def forward(self, x):
+        x1 = self.lrelu(self.conv1(x))
+        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+    '''Residual in Residual Dense Block'''
+
+    def __init__(self, nf, gc=32):
+        super(RRDB, self).__init__()
+        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
+        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
+        self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+
+    def forward(self, x):
+        out = self.RDB1(x)
+        out = self.RDB2(out)
+        out = self.RDB3(out)
+        return out * 0.2 + x
+
+
+class RRDBNet(nn.Module):
+    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
+        super(RRDBNet, self).__init__()
+        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+
+        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
+        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
+        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        #### upsampling
+        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        fea = self.conv_first(x)
+        trunk = self.trunk_conv(self.RRDB_trunk(fea))
+        fea = fea + trunk
+
+        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
+        fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
+        out = self.conv_last(self.lrelu(self.HRconv(fea)))
+
+        return out
modules/esrgan_model.py ADDED
@@ -0,0 +1,160 @@
+import os
+
+import numpy as np
+import torch
+from PIL import Image
+from basicsr.utils.download_util import load_file_from_url
+
+import modules.esrgam_model_arch as arch
+from modules import shared, modelloader, images, devices
+from modules.paths import models_path
+from modules.upscaler import Upscaler, UpscalerData
+from modules.shared import opts
+
+
+def fix_model_layers(crt_model, pretrained_net):
+    # this code is adapted from https://github.com/xinntao/ESRGAN
+    if 'conv_first.weight' in pretrained_net:
+        return pretrained_net
+
+    if 'model.0.weight' not in pretrained_net:
+        is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
+        if is_realesrgan:
+            raise Exception("The file is a RealESRGAN model, it can't be used as an ESRGAN model.")
+        else:
+            raise Exception("The file is not an ESRGAN model.")
+
+    crt_net = crt_model.state_dict()
+    load_net_clean = {}
+    for k, v in pretrained_net.items():
+        if k.startswith('module.'):
+            load_net_clean[k[7:]] = v
+        else:
+            load_net_clean[k] = v
+    pretrained_net = load_net_clean
+
+    tbd = []
+    for k, v in crt_net.items():
+        tbd.append(k)
+
+    # directly copy
+    for k, v in crt_net.items():
+        if k in pretrained_net and pretrained_net[k].size() == v.size():
+            crt_net[k] = pretrained_net[k]
+            tbd.remove(k)
+
+    crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
+    crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
+
+    for k in tbd.copy():
+        if 'RDB' in k:
+            ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
+            if '.weight' in k:
+                ori_k = ori_k.replace('.weight', '.0.weight')
+            elif '.bias' in k:
+                ori_k = ori_k.replace('.bias', '.0.bias')
+            crt_net[k] = pretrained_net[ori_k]
+            tbd.remove(k)
+
+    crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
+    crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
+    crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
+    crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
+    crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
+    crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
+    crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
+    crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
+    crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
+    crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
+
+    return crt_net
+
+class UpscalerESRGAN(Upscaler):
+    def __init__(self, dirname):
+        self.name = "ESRGAN"
+        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
+        self.model_name = "ESRGAN_4x"
+        self.scalers = []
+        self.user_path = dirname
+        self.model_path = os.path.join(models_path, self.name)
+        super().__init__()
+        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
+        scalers = []
+        if len(model_paths) == 0:
+            scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
+            self.scalers.append(scaler_data)  # append to self.scalers so the default entry is not lost
+        for file in model_paths:
+            if "http" in file:
+                name = self.model_name
+            else:
+                name = modelloader.friendly_name(file)
+
+            scaler_data = UpscalerData(name, file, self, 4)
+            self.scalers.append(scaler_data)
+
+    def do_upscale(self, img, selected_model):
+        model = self.load_model(selected_model)
+        if model is None:
+            return img
+        model.to(devices.device_esrgan)
+        img = esrgan_upscale(model, img)
+        return img
+
+    def load_model(self, path: str):
+        if "http" in path:
+            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+                                          file_name="%s.pth" % self.model_name,
+                                          progress=True)
+        else:
+            filename = path
+        if filename is None or not os.path.exists(filename):
+            print("Unable to load %s from %s" % (self.model_path, filename))
+            return None
+
+        pretrained_net = torch.load(filename, map_location='cpu' if shared.device.type == 'mps' else None)
+        crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
+
+        pretrained_net = fix_model_layers(crt_model, pretrained_net)
+        crt_model.load_state_dict(pretrained_net)
+        crt_model.eval()
+
+        return crt_model
+
+
+def upscale_without_tiling(model, img):
+    img = np.array(img)
+    img = img[:, :, ::-1]
+    img = np.moveaxis(img, 2, 0) / 255
+    img = torch.from_numpy(img).float()
+    img = img.unsqueeze(0).to(devices.device_esrgan)
+    with torch.no_grad():
+        output = model(img)
+    output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+    output = 255. * np.moveaxis(output, 0, 2)
+    output = output.astype(np.uint8)
+    output = output[:, :, ::-1]
+    return Image.fromarray(output, 'RGB')
+
+
+def esrgan_upscale(model, img):
+    if opts.ESRGAN_tile == 0:
+        return upscale_without_tiling(model, img)
+
+    grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
+    newtiles = []
+    scale_factor = 1
+
+    for y, h, row in grid.tiles:
+        newrow = []
+        for tiledata in row:
+            x, w, tile = tiledata
+
+            output = upscale_without_tiling(model, tile)
+            scale_factor = output.width // tile.width
+
+            newrow.append([x * scale_factor, w * scale_factor, output])
+        newtiles.append([y * scale_factor, h * scale_factor, newrow])
+
+    newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
+    output = images.combine_grid(newgrid)
+    return output
modules/extras.py ADDED
@@ -0,0 +1,208 @@
+import os
+
+import numpy as np
+from PIL import Image
+
+import torch
+import tqdm
+
+from modules import processing, shared, images, devices, sd_models
+from modules.shared import opts
+import modules.gfpgan_model
+from modules.ui import plaintext_to_html
+import modules.codeformer_model
+import piexif
+import piexif.helper
+import gradio as gr
+
+
+cached_images = {}
+
+
+def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
+    devices.torch_gc()
+
+    imageArr = []
+    # Also keep track of original file names
+    imageNameArr = []
+
+    if extras_mode == 1:
+        # convert file to pillow image
+        for img in image_folder:
+            image = Image.fromarray(np.array(Image.open(img)))
+            imageArr.append(image)
+            imageNameArr.append(os.path.splitext(img.orig_name)[0])
+    else:
+        imageArr.append(image)
+        imageNameArr.append(None)
+
+    outpath = opts.outdir_samples or opts.outdir_extras_samples
+
+    outputs = []
+    for image, image_name in zip(imageArr, imageNameArr):
+        if image is None:
+            return outputs, "Please select an input image.", ''
+        existing_pnginfo = image.info or {}
+
+        image = image.convert("RGB")
+        info = ""
+
+        if gfpgan_visibility > 0:
+            restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
+            res = Image.fromarray(restored_img)
+
+            if gfpgan_visibility < 1.0:
+                res = Image.blend(image, res, gfpgan_visibility)
+
+            info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
+            image = res
+
+        if codeformer_visibility > 0:
+            restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
+            res = Image.fromarray(restored_img)
+
+            if codeformer_visibility < 1.0:
+                res = Image.blend(image, res, codeformer_visibility)
+
+            info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
+            image = res
+
+        if upscaling_resize != 1.0:
+            def upscale(image, scaler_index, resize):
+                small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
+                pixels = tuple(np.array(small).flatten().tolist())
+                key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
+
+                c = cached_images.get(key)
+                if c is None:
+                    upscaler = shared.sd_upscalers[scaler_index]
+                    c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
+                    cached_images[key] = c
+
+                return c
+
+            info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
+            res = upscale(image, extras_upscaler_1, upscaling_resize)
+
+            if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
+                res2 = upscale(image, extras_upscaler_2, upscaling_resize)
+                info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
+                res = Image.blend(res, res2, extras_upscaler_2_visibility)
+
+            image = res
+
+        while len(cached_images) > 2:
+            del cached_images[next(iter(cached_images.keys()))]
+
+        images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
+                          no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
+                          forced_filename=image_name if opts.use_original_name_batch else None)
+
+        outputs.append(image)
+
+    devices.torch_gc()
+
+    return outputs, plaintext_to_html(info), ''
+
+
+def run_pnginfo(image):
+    if image is None:
+        return '', '', ''
+
+    items = image.info
+    geninfo = ''
+
+    if "exif" in image.info:
+        exif = piexif.load(image.info["exif"])
+        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
+        try:
+            exif_comment = piexif.helper.UserComment.load(exif_comment)
+        except ValueError:
+            exif_comment = exif_comment.decode('utf8', errors="ignore")
+
+        items['exif comment'] = exif_comment
+        geninfo = exif_comment
+
+        for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
+                      'loop', 'background', 'timestamp', 'duration']:
+            items.pop(field, None)
+
+    geninfo = items.get('parameters', geninfo)
+
+    info = ''
+    for key, text in items.items():
+        info += f"""
+<div>
+<p><b>{plaintext_to_html(str(key))}</b></p>
+<p>{plaintext_to_html(str(text))}</p>
+</div>
+""".strip()+"\n"
+
+    if len(info) == 0:
+        message = "Nothing found in the image."
+        info = f"<div><p>{message}<p></div>"
+
+    return '', geninfo, info
+
+
+def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
+    # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
+    def weighted_sum(theta0, theta1, alpha):
+        return ((1 - alpha) * theta0) + (alpha * theta1)
+
+    # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
+    def sigmoid(theta0, theta1, alpha):
+        alpha = alpha * alpha * (3 - (2 * alpha))
+        return theta0 + ((theta1 - theta0) * alpha)
+
+    # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
+    def inv_sigmoid(theta0, theta1, alpha):
+        import math
+        alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
+        return theta0 + ((theta1 - theta0) * alpha)
+
+    primary_model_info = sd_models.checkpoints_list[primary_model_name]
+    secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
+
+    print(f"Loading {primary_model_info.filename}...")
+    primary_model = torch.load(primary_model_info.filename, map_location='cpu')
+
+    print(f"Loading {secondary_model_info.filename}...")
+    secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
+
+    theta_0 = primary_model['state_dict']
+    theta_1 = secondary_model['state_dict']
+
+    theta_funcs = {
+        "Weighted Sum": weighted_sum,
+        "Sigmoid": sigmoid,
+        "Inverse Sigmoid": inv_sigmoid,
+    }
+    theta_func = theta_funcs[interp_method]
+
+    print(f"Merging...")
+    for key in tqdm.tqdm(theta_0.keys()):
+        if 'model' in key and key in theta_1:
+            theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount))  # Need to reverse the interp_amount to match the desired mix ratio in the merged checkpoint
+            if save_as_half:
+                theta_0[key] = theta_0[key].half()
+
+    for key in theta_1.keys():
+        if 'model' in key and key not in theta_0:
+            theta_0[key] = theta_1[key]
+            if save_as_half:
+                theta_0[key] = theta_0[key].half()
+
+    ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
+
+    filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
+    filename = filename if custom_name == '' else (custom_name + '.ckpt')
+    output_modelname = os.path.join(ckpt_dir, filename)
+
+    print(f"Saving to {output_modelname}...")
+    torch.save(primary_model, output_modelname)
+
+    sd_models.list_models()
+
+    print(f"Checkpoint saved.")
+    return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)]
modules/face_restoration.py ADDED
@@ -0,0 +1,19 @@
+from modules import shared
+
+
+class FaceRestoration:
+    def name(self):
+        return "None"
+
+    def restore(self, np_image):
+        return np_image
+
+
+def restore_faces(np_image):
+    face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
+    if len(face_restorers) == 0:
+        return np_image
+
+    face_restorer = face_restorers[0]
+
+    return face_restorer.restore(np_image)
modules/generation_parameters_copypaste.py ADDED
@@ -0,0 +1,92 @@
+import re
+import gradio as gr
+
+re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
+re_param = re.compile(re_param_code)
+re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
+re_imagesize = re.compile(r"^(\d+)x(\d+)$")
+type_of_gr_update = type(gr.update())
+
+
+def parse_generation_parameters(x: str):
+    """parses generation parameters string, the one you see in text field under the picture in UI:
+```
+girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
+Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
+Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
+```
+
+    returns a dict with field values
+    """
+
+    res = {}
+
+    prompt = ""
+    negative_prompt = ""
+
+    done_with_prompt = False
+
+    *lines, lastline = x.strip().split("\n")
+    if not re_params.match(lastline):
+        lines.append(lastline)
+        lastline = ''
+
+    for i, line in enumerate(lines):
+        line = line.strip()
+        if line.startswith("Negative prompt:"):
+            done_with_prompt = True
+            line = line[16:].strip()
+
+        if done_with_prompt:
+            negative_prompt += ("" if negative_prompt == "" else "\n") + line
+        else:
+            prompt += ("" if prompt == "" else "\n") + line
+
+    if len(prompt) > 0:
+        res["Prompt"] = prompt
+
+    if len(negative_prompt) > 0:
+        res["Negative prompt"] = negative_prompt
+
+    for k, v in re_param.findall(lastline):
+        m = re_imagesize.match(v)
+        if m is not None:
+            res[k+"-1"] = m.group(1)
+            res[k+"-2"] = m.group(2)
+        else:
+            res[k] = v
+
+    return res
+
+
+def connect_paste(button, paste_fields, input_comp, js=None):
+    def paste_func(prompt):
+        params = parse_generation_parameters(prompt)
+        res = []
+
+        for output, key in paste_fields:
+            if callable(key):
+                v = key(params)
+            else:
+                v = params.get(key, None)
+
+            if v is None:
+                res.append(gr.update())
+            elif isinstance(v, type_of_gr_update):
+                res.append(v)
+            else:
+                try:
+                    valtype = type(output.value)
+                    val = valtype(v)
+                    res.append(gr.update(value=val))
+                except Exception:
+                    res.append(gr.update())
+
+        return res
+
+    button.click(
+        fn=paste_func,
+        _js=js,
+        inputs=[input_comp],
+        outputs=[x[0] for x in paste_fields],
+    )
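A short sketch (not part of the commit) of what `parse_generation_parameters` produces for an infotext block like the one in the docstring; all values come back as strings, and `Size` is split into `Size-1`/`Size-2` by the image-size regex.

```python
from modules.generation_parameters_copypaste import parse_generation_parameters

text = """girl with an artist's beret, detailed, intricate
Negative prompt: ugly, blurry, bad anatomy
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b"""

params = parse_generation_parameters(text)
# expected result:
# {'Prompt': "girl with an artist's beret, detailed, intricate",
#  'Negative prompt': 'ugly, blurry, bad anatomy',
#  'Steps': '20', 'Sampler': 'Euler a', 'CFG scale': '7',
#  'Seed': '965400086', 'Size-1': '512', 'Size-2': '512', 'Model hash': '45dee52b'}
print(params)
```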
modules/gfpgan_model.py ADDED
@@ -0,0 +1,115 @@
+import os
+import sys
+import traceback
+
+import facexlib
+import gfpgan
+
+import modules.face_restoration
+from modules import shared, devices, modelloader
+from modules.paths import models_path
+
+model_dir = "GFPGAN"
+user_path = None
+model_path = os.path.join(models_path, model_dir)
+model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
+have_gfpgan = False
+loaded_gfpgan_model = None
+
+
+def gfpgann():
+    global loaded_gfpgan_model
+    global model_path
+    if loaded_gfpgan_model is not None:
+        loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
+        return loaded_gfpgan_model
+
+    if gfpgan_constructor is None:
+        return None
+
+    models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
+    if len(models) == 1 and "http" in models[0]:
+        model_file = models[0]
+    elif len(models) != 0:
+        latest_file = max(models, key=os.path.getctime)
+        model_file = latest_file
+    else:
+        print("Unable to load gfpgan model!")
+        return None
+    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
+    loaded_gfpgan_model = model
+
+    return model
+
+
+def send_model_to(model, device):
+    model.gfpgan.to(device)
+    model.face_helper.face_det.to(device)
+    model.face_helper.face_parse.to(device)
+
+
+def gfpgan_fix_faces(np_image):
+    model = gfpgann()
+    if model is None:
+        return np_image
+
+    send_model_to(model, devices.device_gfpgan)
+
+    np_image_bgr = np_image[:, :, ::-1]
+    cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
+    np_image = gfpgan_output_bgr[:, :, ::-1]
+
+    model.face_helper.clean_all()
+
+    if shared.opts.face_restoration_unload:
+        send_model_to(model, devices.cpu)
+
+    return np_image
+
+
+gfpgan_constructor = None
+
+
+def setup_model(dirname):
+    global model_path
+    if not os.path.exists(model_path):
+        os.makedirs(model_path)
+
+    try:
+        from gfpgan import GFPGANer
+        from facexlib import detection, parsing
+        global user_path
+        global have_gfpgan
+        global gfpgan_constructor
+
+        load_file_from_url_orig = gfpgan.utils.load_file_from_url
+        facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
+        facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
+
+        def my_load_file_from_url(**kwargs):
+            return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
+
+        def facex_load_file_from_url(**kwargs):
+            return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
+
+        def facex_load_file_from_url2(**kwargs):
+            return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
+
+        gfpgan.utils.load_file_from_url = my_load_file_from_url
+        facexlib.detection.load_file_from_url = facex_load_file_from_url
+        facexlib.parsing.load_file_from_url = facex_load_file_from_url2
+        user_path = dirname
+        have_gfpgan = True
+        gfpgan_constructor = GFPGANer
+
+        class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
+            def name(self):
+                return "GFPGAN"
+
+            def restore(self, np_image):
+                return gfpgan_fix_faces(np_image)
+
+        shared.face_restorers.append(FaceRestorerGFPGAN())
+    except Exception:
+        print("Error setting up GFPGAN:", file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
modules/hypernetwork.py ADDED
@@ -0,0 +1,88 @@
+import glob
+import os
+import sys
+import traceback
+
+import torch
+
+from ldm.util import default
+from modules import devices, shared
+import torch
+from torch import einsum
+from einops import rearrange, repeat
+
+
+class HypernetworkModule(torch.nn.Module):
+    def __init__(self, dim, state_dict):
+        super().__init__()
+
+        self.linear1 = torch.nn.Linear(dim, dim * 2)
+        self.linear2 = torch.nn.Linear(dim * 2, dim)
+
+        self.load_state_dict(state_dict, strict=True)
+        self.to(devices.device)
+
+    def forward(self, x):
+        return x + (self.linear2(self.linear1(x)))
+
+
+class Hypernetwork:
+    filename = None
+    name = None
+
+    def __init__(self, filename):
+        self.filename = filename
+        self.name = os.path.splitext(os.path.basename(filename))[0]
+        self.layers = {}
+
+        state_dict = torch.load(filename, map_location='cpu')
+        for size, sd in state_dict.items():
+            self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
+
+
+def load_hypernetworks(path):
+    res = {}
+
+    for filename in glob.iglob(path + '**/*.pt', recursive=True):
+        try:
+            hn = Hypernetwork(filename)
+            res[hn.name] = hn
+        except Exception:
+            print(f"Error loading hypernetwork {filename}", file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
+
+    return res
+
+
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+
+    hypernetwork = shared.selected_hypernetwork()
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+    if hypernetwork_layers is not None:
+        k = self.to_k(hypernetwork_layers[0](context))
+        v = self.to_v(hypernetwork_layers[1](context))
+    else:
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+    if mask is not None:
+        mask = rearrange(mask, 'b ... -> b (...)')
+        max_neg_value = -torch.finfo(sim.dtype).max
+        mask = repeat(mask, 'b j -> (b h) () j', h=h)
+        sim.masked_fill_(~mask, max_neg_value)
+
+    # attention, what we cannot get enough of
+    attn = sim.softmax(dim=-1)
+
+    out = einsum('b i j, b j d -> b i d', attn, v)
+    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(out)
modules/images.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import math
3
+ import os
4
+ from collections import namedtuple
5
+ import re
6
+
7
+ import numpy as np
8
+ import piexif
9
+ import piexif.helper
10
+ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
11
+ from fonts.ttf import Roboto
12
+ import string
13
+
14
+ from modules import sd_samplers, shared
15
+ from modules.shared import opts, cmd_opts
16
+
17
+ LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
18
+
19
+
20
+ def image_grid(imgs, batch_size=1, rows=None):
21
+ if rows is None:
22
+ if opts.n_rows > 0:
23
+ rows = opts.n_rows
24
+ elif opts.n_rows == 0:
25
+ rows = batch_size
26
+ else:
27
+ rows = math.sqrt(len(imgs))
28
+ rows = round(rows)
29
+
30
+ cols = math.ceil(len(imgs) / rows)
31
+
32
+ w, h = imgs[0].size
33
+ grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
34
+
35
+ for i, img in enumerate(imgs):
36
+ grid.paste(img, box=(i % cols * w, i // cols * h))
37
+
38
+ return grid
39
+
40
+
41
+ Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
42
+
43
+
44
+ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
45
+ w = image.width
46
+ h = image.height
47
+
48
+ non_overlap_width = tile_w - overlap
49
+ non_overlap_height = tile_h - overlap
50
+
51
+ cols = math.ceil((w - overlap) / non_overlap_width)
52
+ rows = math.ceil((h - overlap) / non_overlap_height)
53
+
54
+ dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
55
+ dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
56
+
57
+ grid = Grid([], tile_w, tile_h, w, h, overlap)
58
+ for row in range(rows):
59
+ row_images = []
60
+
61
+ y = int(row * dy)
62
+
63
+ if y + tile_h >= h:
64
+ y = h - tile_h
65
+
66
+ for col in range(cols):
67
+ x = int(col * dx)
68
+
69
+ if x + tile_w >= w:
70
+ x = w - tile_w
71
+
72
+ tile = image.crop((x, y, x + tile_w, y + tile_h))
73
+
74
+ row_images.append([x, tile_w, tile])
75
+
76
+ grid.tiles.append([y, tile_h, row_images])
77
+
78
+ return grid
79
+
80
+
81
+ def combine_grid(grid):
82
+ def make_mask_image(r):
83
+ r = r * 255 / grid.overlap
84
+ r = r.astype(np.uint8)
85
+ return Image.fromarray(r, 'L')
86
+
87
+ mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
88
+ mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
89
+
90
+ combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
91
+ for y, h, row in grid.tiles:
92
+ combined_row = Image.new("RGB", (grid.image_w, h))
93
+ for x, w, tile in row:
94
+ if x == 0:
95
+ combined_row.paste(tile, (0, 0))
96
+ continue
97
+
98
+ combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
99
+ combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
100
+
101
+ if y == 0:
102
+ combined_image.paste(combined_row, (0, 0))
103
+ continue
104
+
105
+ combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
106
+ combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
107
+
108
+ return combined_image
109
+
110
+
111
+ class GridAnnotation:
112
+ def __init__(self, text='', is_active=True):
113
+ self.text = text
114
+ self.is_active = is_active
115
+ self.size = None
116
+
117
+
118
+ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
119
+ def wrap(drawing, text, font, line_length):
120
+ lines = ['']
121
+ for word in text.split():
122
+ line = f'{lines[-1]} {word}'.strip()
123
+ if drawing.textlength(line, font=font) <= line_length:
124
+ lines[-1] = line
125
+ else:
126
+ lines.append(word)
127
+ return lines
128
+
129
+ def draw_texts(drawing, draw_x, draw_y, lines):
130
+ for i, line in enumerate(lines):
131
+ drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
132
+
133
+ if not line.is_active:
134
+ drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
135
+
136
+ draw_y += line.size[1] + line_spacing
137
+
138
+ fontsize = (width + height) // 25
139
+ line_spacing = fontsize // 2
140
+
141
+ try:
142
+ fnt = ImageFont.truetype(opts.font or Roboto, fontsize)
143
+ except Exception:
144
+ fnt = ImageFont.truetype(Roboto, fontsize)
145
+
146
+ color_active = (0, 0, 0)
147
+ color_inactive = (153, 153, 153)
148
+
149
+ pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
150
+
151
+ cols = im.width // width
152
+ rows = im.height // height
153
+
154
+ assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
155
+ assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
156
+
157
+ calc_img = Image.new("RGB", (1, 1), "white")
158
+ calc_d = ImageDraw.Draw(calc_img)
159
+
160
+ for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
161
+ items = [] + texts
162
+ texts.clear()
163
+
164
+ for line in items:
165
+ wrapped = wrap(calc_d, line.text, fnt, allowed_width)
166
+ texts += [GridAnnotation(x, line.is_active) for x in wrapped]
167
+
168
+ for line in texts:
169
+ bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
170
+ line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
171
+
172
+ hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
173
+ ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
174
+ ver_texts]
175
+
176
+ pad_top = max(hor_text_heights) + line_spacing * 2
177
+
178
+ result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
179
+ result.paste(im, (pad_left, pad_top))
180
+
181
+ d = ImageDraw.Draw(result)
182
+
183
+ for col in range(cols):
184
+ x = pad_left + width * col + width / 2
185
+ y = pad_top / 2 - hor_text_heights[col] / 2
186
+
187
+ draw_texts(d, x, y, hor_texts[col])
188
+
189
+ for row in range(rows):
190
+ x = pad_left / 2
191
+ y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
192
+
193
+ draw_texts(d, x, y, ver_texts[row])
194
+
195
+ return result
196
+
197
+
198
+ def draw_prompt_matrix(im, width, height, all_prompts):
199
+ prompts = all_prompts[1:]
200
+ boundary = math.ceil(len(prompts) / 2)
201
+
202
+ prompts_horiz = prompts[:boundary]
203
+ prompts_vert = prompts[boundary:]
204
+
205
+ hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
206
+ ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
207
+
208
+ return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
209
+
210
+
211
+ def resize_image(resize_mode, im, width, height):
212
+ def resize(im, w, h):
213
+ if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
214
+ return im.resize((w, h), resample=LANCZOS)
215
+
216
+ scale = max(w / im.width, h / im.height)
217
+
218
+ if scale > 1.0:
219
+ upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
220
+ assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
221
+
222
+ upscaler = upscalers[0]
223
+ im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
224
+
225
+ if im.width != w or im.height != h:
226
+ im = im.resize((w, h), resample=LANCZOS)
227
+
228
+ return im
229
+
230
+ if resize_mode == 0:
231
+ res = resize(im, width, height)
232
+
233
+ elif resize_mode == 1:
234
+ ratio = width / height
235
+ src_ratio = im.width / im.height
236
+
237
+ src_w = width if ratio > src_ratio else im.width * height // im.height
238
+ src_h = height if ratio <= src_ratio else im.height * width // im.width
239
+
240
+ resized = resize(im, src_w, src_h)
241
+ res = Image.new("RGB", (width, height))
242
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
243
+
244
+ else:
245
+ ratio = width / height
246
+ src_ratio = im.width / im.height
247
+
248
+ src_w = width if ratio < src_ratio else im.width * height // im.height
249
+ src_h = height if ratio >= src_ratio else im.height * width // im.width
250
+
251
+ resized = resize(im, src_w, src_h)
252
+ res = Image.new("RGB", (width, height))
253
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
254
+
255
+ if ratio < src_ratio:
256
+ fill_height = height // 2 - src_h // 2
257
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
258
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
259
+ elif ratio > src_ratio:
260
+ fill_width = width // 2 - src_w // 2
261
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
262
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
263
+
264
+ return res
265
+
266
+
267
+ invalid_filename_chars = '<>:"/\\|?*\n'
268
+ invalid_filename_prefix = ' '
269
+ invalid_filename_postfix = ' .'
270
+ re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
271
+ max_filename_part_length = 128
272
+
273
+
274
+ def sanitize_filename_part(text, replace_spaces=True):
275
+ if replace_spaces:
276
+ text = text.replace(' ', '_')
277
+
278
+ text = text.translate({ord(x): '_' for x in invalid_filename_chars})
279
+ text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
280
+ text = text.rstrip(invalid_filename_postfix)
281
+ return text
282
+
283
+
284
+ def apply_filename_pattern(x, p, seed, prompt):
285
+ max_prompt_words = opts.directories_max_prompt_words
286
+
287
+ if seed is not None:
288
+ x = x.replace("[seed]", str(seed))
289
+
290
+ if p is not None:
291
+ x = x.replace("[steps]", str(p.steps))
292
+ x = x.replace("[cfg]", str(p.cfg_scale))
293
+ x = x.replace("[width]", str(p.width))
294
+ x = x.replace("[height]", str(p.height))
295
+ x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
296
+ x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
297
+
298
+ x = x.replace("[model_hash]", getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash))
299
+ x = x.replace("[date]", datetime.date.today().isoformat())
300
+ x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
301
+ x = x.replace("[job_timestamp]", getattr(p, "job_timestamp", shared.state.job_timestamp))
302
+
303
+ # Apply [prompt] at last. Because it may contain any replacement word.^M
304
+ if prompt is not None:
305
+ x = x.replace("[prompt]", sanitize_filename_part(prompt))
306
+ if "[prompt_no_styles]" in x:
307
+ prompt_no_style = prompt
308
+ for style in shared.prompt_styles.get_style_prompts(p.styles):
309
+ if len(style) > 0:
310
+ style_parts = [y for y in style.split("{prompt}")]
311
+ for part in style_parts:
312
+ prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
313
+ prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
314
+ x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
315
+
316
+ x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False))
317
+ if "[prompt_words]" in x:
318
+ words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
319
+ if len(words) == 0:
320
+ words = ["empty"]
321
+ x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
322
+
323
+ if cmd_opts.hide_ui_dir_config:
324
+ x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
325
+
326
+ return x
327
+
328
+
329
+ def get_next_sequence_number(path, basename):
330
+ """
331
+ Determines and returns the next sequence number to use when saving an image in the specified directory.
332
+
333
+ The sequence starts at 0.
334
+ """
335
+ result = -1
336
+ if basename != '':
337
+ basename = basename + "-"
338
+
339
+ prefix_length = len(basename)
340
+ for p in os.listdir(path):
341
+ if p.startswith(basename):
342
+ l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
343
+ try:
344
+ result = max(int(l[0]), result)
345
+ except ValueError:
346
+ pass
347
+
348
+ return result + 1
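A quick, self-contained check of the sequence logic above (assumes get_next_sequence_number from this module is in scope; uses a throwaway temporary directory):

    import os, tempfile

    outdir = tempfile.mkdtemp()
    for name in ("00001-cat.png", "00007-dog.png", "notes.txt"):
        open(os.path.join(outdir, name), "w").close()

    # 8: one past the highest leading number; names without a leading number are ignored
    print(get_next_sequence_number(outdir, ""))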
349
+
350
+
351
+ def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
352
+ if short_filename or prompt is None or seed is None:
353
+ file_decoration = ""
354
+ elif opts.save_to_dirs:
355
+ file_decoration = opts.samples_filename_pattern or "[seed]"
356
+ else:
357
+ file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
358
+
359
+ if file_decoration != "":
360
+ file_decoration = "-" + file_decoration.lower()
361
+
362
+ file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix
363
+
364
+ if extension == 'png' and opts.enable_pnginfo and info is not None:
365
+ pnginfo = PngImagePlugin.PngInfo()
366
+
367
+ if existing_info is not None:
368
+ for k, v in existing_info.items():
369
+ pnginfo.add_text(k, str(v))
370
+
371
+ pnginfo.add_text(pnginfo_section_name, info)
372
+ else:
373
+ pnginfo = None
374
+
375
+ if save_to_dirs is None:
376
+ save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
377
+
378
+ if save_to_dirs:
379
+ dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
380
+ path = os.path.join(path, dirname)
381
+
382
+ os.makedirs(path, exist_ok=True)
383
+
384
+ if forced_filename is None:
385
+ basecount = get_next_sequence_number(path, basename)
386
+ fullfn = "a.png"  # placeholder; overwritten in the loop below
387
+ fullfn_without_extension = "a"
388
+ for i in range(500):
389
+ fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
390
+ fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
391
+ fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
392
+ if not os.path.exists(fullfn):
393
+ break
394
+ else:
395
+ fullfn = os.path.join(path, f"{forced_filename}.{extension}")
396
+ fullfn_without_extension = os.path.join(path, forced_filename)
397
+
398
+ def exif_bytes():
399
+ return piexif.dump({
400
+ "Exif": {
401
+ piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
402
+ },
403
+ })
404
+
405
+ if extension.lower() in ("jpg", "jpeg", "webp"):
406
+ image.save(fullfn, quality=opts.jpeg_quality)
407
+ if opts.enable_pnginfo and info is not None:
408
+ piexif.insert(exif_bytes(), fullfn)
409
+ else:
410
+ image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
411
+
412
+ target_side_length = 4000
413
+ oversize = image.width > target_side_length or image.height > target_side_length
414
+ if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
415
+ ratio = image.width / image.height
416
+
417
+ if oversize and ratio > 1:
418
+ image = image.resize((target_side_length, image.height * target_side_length // image.width), LANCZOS)
419
+ elif oversize:
420
+ image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)
421
+
422
+ image.save(fullfn_without_extension + ".jpg", quality=opts.jpeg_quality)
423
+ if opts.enable_pnginfo and info is not None:
424
+ piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")
425
+
426
+ if opts.save_txt and info is not None:
427
+ with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file:
428
+ file.write(info + "\n")
429
+
430
+ return fullfn
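To illustrate the token substitution that apply_filename_pattern() performs, here is a minimal stand-alone sketch of the same idea. It deliberately avoids opts/shared and only handles [seed] and [prompt_words], so it is an approximation of the behaviour, not the module's own API:

    import re
    import string

    re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')

    def expand_pattern(pattern, seed, prompt, max_words=8):
        # split the prompt into words, drop punctuation, cap the word count
        words = [w for w in re_nonletters.split(prompt) if w] or ["empty"]
        return (pattern
                .replace("[seed]", str(seed))
                .replace("[prompt_words]", " ".join(words[:max_words])))

    print(expand_pattern("[seed]-[prompt_words]", 1234, "a photo of a cat, highly detailed"))
    # -> 1234-a photo of a cat highly detailed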
modules/img2img.py ADDED
@@ -0,0 +1,135 @@
1
+ import math
2
+ import os
3
+ import sys
4
+ import traceback
5
+
6
+ import numpy as np
7
+ from PIL import Image, ImageOps, ImageChops
8
+
9
+ from modules import devices
10
+ from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
11
+ from modules.shared import opts, state
12
+ import modules.shared as shared
13
+ import modules.processing as processing
14
+ from modules.ui import plaintext_to_html
15
+ import modules.images as images
16
+ import modules.scripts
17
+
18
+
19
+ def process_batch(p, input_dir, output_dir, args):
20
+ processing.fix_seed(p)
21
+
22
+ images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
23
+
24
+ print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
25
+
26
+ save_normally = output_dir == ''
27
+
28
+ p.do_not_save_grid = True
29
+ p.do_not_save_samples = not save_normally
30
+
31
+ state.job_count = len(images) * p.n_iter
32
+
33
+ for i, image in enumerate(images):
34
+ state.job = f"{i+1} out of {len(images)}"
35
+
36
+ if state.interrupted:
37
+ break
38
+
39
+ img = Image.open(image)
40
+ p.init_images = [img] * p.batch_size
41
+
42
+ proc = modules.scripts.scripts_img2img.run(p, *args)
43
+ if proc is None:
44
+ proc = process_images(p)
45
+
46
+ for n, processed_image in enumerate(proc.images):
47
+ filename = os.path.basename(image)
48
+
49
+ if n > 0:
50
+ left, right = os.path.splitext(filename)
51
+ filename = f"{left}-{n}{right}"
52
+
53
+ if not save_normally:
54
+ processed_image.save(os.path.join(output_dir, filename))
55
+
56
+
57
+ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
58
+ is_inpaint = mode == 1
59
+ is_batch = mode == 2
60
+
61
+ if is_inpaint:
62
+ if mask_mode == 0:
63
+ image = init_img_with_mask['image']
64
+ mask = init_img_with_mask['mask']
65
+ alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
66
+ mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
67
+ image = image.convert('RGB')
68
+ else:
69
+ image = init_img_inpaint
70
+ mask = init_mask_inpaint
71
+ else:
72
+ image = init_img
73
+ mask = None
74
+
75
+ assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
76
+
77
+ p = StableDiffusionProcessingImg2Img(
78
+ sd_model=shared.sd_model,
79
+ outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
80
+ outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
81
+ prompt=prompt,
82
+ negative_prompt=negative_prompt,
83
+ styles=[prompt_style, prompt_style2],
84
+ seed=seed,
85
+ subseed=subseed,
86
+ subseed_strength=subseed_strength,
87
+ seed_resize_from_h=seed_resize_from_h,
88
+ seed_resize_from_w=seed_resize_from_w,
89
+ seed_enable_extras=seed_enable_extras,
90
+ sampler_index=sampler_index,
91
+ batch_size=batch_size,
92
+ n_iter=n_iter,
93
+ steps=steps,
94
+ cfg_scale=cfg_scale,
95
+ width=width,
96
+ height=height,
97
+ restore_faces=restore_faces,
98
+ tiling=tiling,
99
+ init_images=[image],
100
+ mask=mask,
101
+ mask_blur=mask_blur,
102
+ inpainting_fill=inpainting_fill,
103
+ resize_mode=resize_mode,
104
+ denoising_strength=denoising_strength,
105
+ inpaint_full_res=inpaint_full_res,
106
+ inpaint_full_res_padding=inpaint_full_res_padding,
107
+ inpainting_mask_invert=inpainting_mask_invert,
108
+ )
109
+
110
+ if shared.cmd_opts.enable_console_prompts:
111
+ print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
112
+
113
+ p.extra_generation_params["Mask blur"] = mask_blur
114
+
115
+ if is_batch:
116
+ assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
117
+
118
+ process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, args)
119
+
120
+ processed = Processed(p, [], p.seed, "")
121
+ else:
122
+ processed = modules.scripts.scripts_img2img.run(p, *args)
123
+ if processed is None:
124
+ processed = process_images(p)
125
+
126
+ shared.total_tqdm.clear()
127
+
128
+ generation_info_js = processed.js()
129
+ if opts.samples_log_stdout:
130
+ print(generation_info_js)
131
+
132
+ if opts.do_not_show_images:
133
+ processed.images = []
134
+
135
+ return processed.images, generation_info_js, plaintext_to_html(processed.info)
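The inpaint branch above merges two sources of mask information: pixels the user erased (transparent in the edited image) and pixels painted in the mask layer. A small self-contained PIL sketch of that combination, using synthetic images and converting both masks to 'L' before ImageChops.lighter for clarity:

    from PIL import Image, ImageChops, ImageOps

    image = Image.new('RGBA', (64, 64), (255, 0, 0, 255))  # stand-in for the edited image
    image.putpixel((0, 0), (0, 0, 0, 0))                   # one erased (fully transparent) pixel
    mask = Image.new('L', (64, 64), 0)                     # stand-in for the painted mask layer
    mask.putpixel((1, 1), 255)                             # one painted pixel

    # transparent pixels become white in the alpha mask, then the two masks are OR-ed via lighter()
    alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
    combined = ImageChops.lighter(alpha_mask.convert('L'), mask).convert('L')
    print(combined.getpixel((0, 0)), combined.getpixel((1, 1)), combined.getpixel((2, 2)))  # 255 255 0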
modules/interrogate.py ADDED
@@ -0,0 +1,168 @@
1
+ import contextlib
2
+ import os
3
+ import sys
4
+ import traceback
5
+ from collections import namedtuple
6
+ import re
7
+
8
+ import torch
9
+
10
+ from torchvision import transforms
11
+ from torchvision.transforms.functional import InterpolationMode
12
+
13
+ import modules.shared as shared
14
+ from modules import devices, paths, lowvram
15
+
16
+ blip_image_eval_size = 384
17
+ blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
18
+ clip_model_name = 'ViT-L/14'
19
+
20
+ Category = namedtuple("Category", ["name", "topn", "items"])
21
+
22
+ re_topn = re.compile(r"\.top(\d+)\.")
23
+
24
+
25
+ class InterrogateModels:
26
+ blip_model = None
27
+ clip_model = None
28
+ clip_preprocess = None
29
+ categories = None
30
+ dtype = None
31
+
32
+ def __init__(self, content_dir):
33
+ self.categories = []
34
+
35
+ if os.path.exists(content_dir):
36
+ for filename in os.listdir(content_dir):
37
+ m = re_topn.search(filename)
38
+ topn = 1 if m is None else int(m.group(1))
39
+
40
+ with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
41
+ lines = [x.strip() for x in file.readlines()]
42
+
43
+ self.categories.append(Category(name=filename, topn=topn, items=lines))
44
+
45
+ def load_blip_model(self):
46
+ import models.blip
47
+
48
+ blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
49
+ blip_model.eval()
50
+
51
+ return blip_model
52
+
53
+ def load_clip_model(self):
54
+ import clip
55
+
56
+ model, preprocess = clip.load(clip_model_name)
57
+ model.eval()
58
+ model = model.to(shared.device)
59
+
60
+ return model, preprocess
61
+
62
+ def load(self):
63
+ if self.blip_model is None:
64
+ self.blip_model = self.load_blip_model()
65
+ if not shared.cmd_opts.no_half:
66
+ self.blip_model = self.blip_model.half()
67
+
68
+ self.blip_model = self.blip_model.to(shared.device)
69
+
70
+ if self.clip_model is None:
71
+ self.clip_model, self.clip_preprocess = self.load_clip_model()
72
+ if not shared.cmd_opts.no_half:
73
+ self.clip_model = self.clip_model.half()
74
+
75
+ self.clip_model = self.clip_model.to(shared.device)
76
+
77
+ self.dtype = next(self.clip_model.parameters()).dtype
78
+
79
+ def send_clip_to_ram(self):
80
+ if not shared.opts.interrogate_keep_models_in_memory:
81
+ if self.clip_model is not None:
82
+ self.clip_model = self.clip_model.to(devices.cpu)
83
+
84
+ def send_blip_to_ram(self):
85
+ if not shared.opts.interrogate_keep_models_in_memory:
86
+ if self.blip_model is not None:
87
+ self.blip_model = self.blip_model.to(devices.cpu)
88
+
89
+ def unload(self):
90
+ self.send_clip_to_ram()
91
+ self.send_blip_to_ram()
92
+
93
+ devices.torch_gc()
94
+
95
+ def rank(self, image_features, text_array, top_count=1):
96
+ import clip
97
+
98
+ if shared.opts.interrogate_clip_dict_limit != 0:
99
+ text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
100
+
101
+ top_count = min(top_count, len(text_array))
102
+ text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device)
103
+ text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
104
+ text_features /= text_features.norm(dim=-1, keepdim=True)
105
+
106
+ similarity = torch.zeros((1, len(text_array))).to(shared.device)
107
+ for i in range(image_features.shape[0]):
108
+ similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
109
+ similarity /= image_features.shape[0]
110
+
111
+ top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
112
+ return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
113
+
114
+ def generate_caption(self, pil_image):
115
+ gpu_image = transforms.Compose([
116
+ transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
117
+ transforms.ToTensor(),
118
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
119
+ ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
120
+
121
+ with torch.no_grad():
122
+ caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
123
+
124
+ return caption[0]
125
+
126
+ def interrogate(self, pil_image):
127
+ res = None
128
+
129
+ try:
130
+
131
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
132
+ lowvram.send_everything_to_cpu()
133
+ devices.torch_gc()
134
+
135
+ self.load()
136
+
137
+ caption = self.generate_caption(pil_image)
138
+ self.send_blip_to_ram()
139
+ devices.torch_gc()
140
+
141
+ res = caption
142
+
143
+ clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
144
+
145
+ precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
146
+ with torch.no_grad(), precision_scope("cuda"):
147
+ image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
148
+
149
+ image_features /= image_features.norm(dim=-1, keepdim=True)
150
+
151
+ if shared.opts.interrogate_use_builtin_artists:
152
+ artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
153
+
154
+ res += ", " + artist[0]
155
+
156
+ for name, topn, items in self.categories:
157
+ matches = self.rank(image_features, items, top_count=topn)
158
+ for match, score in matches:
159
+ res += ", " + match
160
+
161
+ except Exception:
162
+ print("Error interrogating", file=sys.stderr)
163
+ print(traceback.format_exc(), file=sys.stderr)
164
+ res = (res or "") + "<error>"  # res may still be None if captioning itself failed
165
+
166
+ self.unload()
167
+
168
+ return res
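The core of rank() is cosine similarity between L2-normalized image and text features, scaled by 100 and turned into a softmax distribution before taking the top-k entries. A plain-torch sketch of that math, with random features standing in for real CLIP outputs:

    import torch

    labels = ["painting", "photograph", "sketch"]
    image_features = torch.nn.functional.normalize(torch.randn(1, 512), dim=-1)
    text_features = torch.nn.functional.normalize(torch.randn(len(labels), 512), dim=-1)

    # scaled cosine similarities -> probability distribution over the label list
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    top_probs, top_idx = similarity.topk(2, dim=-1)
    print([(labels[int(i)], float(p) * 100) for i, p in zip(top_idx[0], top_probs[0])])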
modules/ldsr_model.py ADDED
@@ -0,0 +1,56 @@
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ from basicsr.utils.download_util import load_file_from_url
6
+
7
+ from modules.upscaler import Upscaler, UpscalerData
8
+ from modules.ldsr_model_arch import LDSR
9
+ from modules import shared
10
+ from modules.paths import models_path
11
+
12
+
13
+ class UpscalerLDSR(Upscaler):
14
+ def __init__(self, user_path):
15
+ self.name = "LDSR"
16
+ self.model_path = os.path.join(models_path, self.name)
17
+ self.user_path = user_path
18
+ self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
19
+ self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
20
+ super().__init__()
21
+ scaler_data = UpscalerData("LDSR", None, self)
22
+ self.scalers = [scaler_data]
23
+
24
+ def load_model(self, path: str):
25
+ # Remove incorrect project.yaml file if too big
26
+ yaml_path = os.path.join(self.model_path, "project.yaml")
27
+ old_model_path = os.path.join(self.model_path, "model.pth")
28
+ new_model_path = os.path.join(self.model_path, "model.ckpt")
29
+ if os.path.exists(yaml_path):
30
+ statinfo = os.stat(yaml_path)
31
+ if statinfo.st_size >= 10485760:
32
+ print("Removing invalid LDSR YAML file.")
33
+ os.remove(yaml_path)
34
+ if os.path.exists(old_model_path):
35
+ print("Renaming model from model.pth to model.ckpt")
36
+ os.rename(old_model_path, new_model_path)
37
+ model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
38
+ file_name="model.ckpt", progress=True)
39
+ yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
40
+ file_name="project.yaml", progress=True)
41
+
42
+ try:
43
+ return LDSR(model, yaml)
44
+
45
+ except Exception:
46
+ print("Error importing LDSR:", file=sys.stderr)
47
+ print(traceback.format_exc(), file=sys.stderr)
48
+ return None
49
+
50
+ def do_upscale(self, img, path):
51
+ ldsr = self.load_model(path)
52
+ if ldsr is None:
53
+ print("NO LDSR!")
54
+ return img
55
+ ddim_steps = shared.opts.ldsr_steps
56
+ return ldsr.super_resolution(img, ddim_steps, self.scale)
modules/ldsr_model_arch.py ADDED
@@ -0,0 +1,222 @@
1
+ import gc
2
+ import time
3
+ import warnings
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchvision
8
+ from PIL import Image
9
+ from einops import rearrange, repeat
10
+ from omegaconf import OmegaConf
11
+
12
+ from ldm.models.diffusion.ddim import DDIMSampler
13
+ from ldm.util import instantiate_from_config, ismap
14
+
15
+ warnings.filterwarnings("ignore", category=UserWarning)
16
+
17
+
18
+ # Create LDSR Class
19
+ class LDSR:
20
+ def load_model_from_config(self, half_attention):
21
+ print(f"Loading model from {self.modelPath}")
22
+ pl_sd = torch.load(self.modelPath, map_location="cpu")
23
+ sd = pl_sd["state_dict"]
24
+ config = OmegaConf.load(self.yamlPath)
25
+ model = instantiate_from_config(config.model)
26
+ model.load_state_dict(sd, strict=False)
27
+ model.cuda()
28
+ if half_attention:
29
+ model = model.half()
30
+
31
+ model.eval()
32
+ return {"model": model}
33
+
34
+ def __init__(self, model_path, yaml_path):
35
+ self.modelPath = model_path
36
+ self.yamlPath = yaml_path
37
+
38
+ @staticmethod
39
+ def run(model, selected_path, custom_steps, eta):
40
+ example = get_cond(selected_path)
41
+
42
+ n_runs = 1
43
+ guider = None
44
+ ckwargs = None
45
+ ddim_use_x0_pred = False
46
+ temperature = 1.
47
+ eta = eta
48
+ custom_shape = None
49
+
50
+ height, width = example["image"].shape[1:3]
51
+ split_input = height >= 128 and width >= 128
52
+
53
+ if split_input:
54
+ ks = 128
55
+ stride = 64
56
+ vqf = 4  # downsampling factor of the first-stage autoencoder
57
+ model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
58
+ "vqf": vqf,
59
+ "patch_distributed_vq": True,
60
+ "tie_braker": False,
61
+ "clip_max_weight": 0.5,
62
+ "clip_min_weight": 0.01,
63
+ "clip_max_tie_weight": 0.5,
64
+ "clip_min_tie_weight": 0.01}
65
+ else:
66
+ if hasattr(model, "split_input_params"):
67
+ delattr(model, "split_input_params")
68
+
69
+ x_t = None
70
+ logs = None
71
+ for n in range(n_runs):
72
+ if custom_shape is not None:
73
+ x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
74
+ x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
75
+
76
+ logs = make_convolutional_sample(example, model,
77
+ custom_steps=custom_steps,
78
+ eta=eta, quantize_x0=False,
79
+ custom_shape=custom_shape,
80
+ temperature=temperature, noise_dropout=0.,
81
+ corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
82
+ ddim_use_x0_pred=ddim_use_x0_pred
83
+ )
84
+ return logs
85
+
86
+ def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
87
+ model = self.load_model_from_config(half_attention)
88
+
89
+ # Run settings
90
+ diffusion_steps = int(steps)
91
+ eta = 1.0
92
+
93
+ down_sample_method = 'Lanczos'
94
+
95
+ gc.collect()
96
+ torch.cuda.empty_cache()
97
+
98
+ im_og = image
99
+ width_og, height_og = im_og.size
100
+ # If the maximum upscale factor ever becomes adjustable, the hard-coded 4 below should become that variable
101
+ down_sample_rate = target_scale / 4
102
+ wd = width_og * down_sample_rate
103
+ hd = height_og * down_sample_rate
104
+ width_downsampled_pre = int(wd)
105
+ height_downsampled_pre = int(hd)
106
+
107
+ if down_sample_rate != 1:
108
+ print(
109
+ f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
110
+ im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
111
+ else:
112
+ print(f"Downsampling rate is 1 ({target_scale} / 4); not downsampling")
113
+ logs = self.run(model["model"], im_og, diffusion_steps, eta)
114
+
115
+ sample = logs["sample"]
116
+ sample = sample.detach().cpu()
117
+ sample = torch.clamp(sample, -1., 1.)
118
+ sample = (sample + 1.) / 2. * 255
119
+ sample = sample.numpy().astype(np.uint8)
120
+ sample = np.transpose(sample, (0, 2, 3, 1))
121
+ a = Image.fromarray(sample[0])
122
+
123
+ del model
124
+ gc.collect()
125
+ torch.cuda.empty_cache()
126
+ return a
127
+
128
+
129
+ def get_cond(selected_path):
130
+ example = dict()
131
+ up_f = 4
132
+ c = selected_path.convert('RGB')
133
+ c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
134
+ c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
135
+ antialias=True)
136
+ c_up = rearrange(c_up, '1 c h w -> 1 h w c')
137
+ c = rearrange(c, '1 c h w -> 1 h w c')
138
+ c = 2. * c - 1.
139
+
140
+ c = c.to(torch.device("cuda"))
141
+ example["LR_image"] = c
142
+ example["image"] = c_up
143
+
144
+ return example
145
+
146
+
147
+ @torch.no_grad()
148
+ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
149
+ mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
150
+ corrector_kwargs=None, x_t=None
151
+ ):
152
+ ddim = DDIMSampler(model)
153
+ bs = shape[0]
154
+ shape = shape[1:]
155
+ print(f"Sampling with eta = {eta}; steps: {steps}")
156
+ samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
157
+ normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
158
+ mask=mask, x0=x0, temperature=temperature, verbose=False,
159
+ score_corrector=score_corrector,
160
+ corrector_kwargs=corrector_kwargs, x_t=x_t)
161
+
162
+ return samples, intermediates
163
+
164
+
165
+ @torch.no_grad()
166
+ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
167
+ corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
168
+ log = dict()
169
+
170
+ z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
171
+ return_first_stage_outputs=True,
172
+ force_c_encode=not (hasattr(model, 'split_input_params')
173
+ and model.cond_stage_key == 'coordinates_bbox'),
174
+ return_original_cond=True)
175
+
176
+ if custom_shape is not None:
177
+ z = torch.randn(custom_shape)
178
+ print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
179
+
180
+ z0 = None
181
+
182
+ log["input"] = x
183
+ log["reconstruction"] = xrec
184
+
185
+ if ismap(xc):
186
+ log["original_conditioning"] = model.to_rgb(xc)
187
+ if hasattr(model, 'cond_stage_key'):
188
+ log[model.cond_stage_key] = model.to_rgb(xc)
189
+
190
+ else:
191
+ log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
192
+ if model.cond_stage_model:
193
+ log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
194
+ if model.cond_stage_key == 'class_label':
195
+ log[model.cond_stage_key] = xc[model.cond_stage_key]
196
+
197
+ with model.ema_scope("Plotting"):
198
+ t0 = time.time()
199
+
200
+ sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
201
+ eta=eta,
202
+ quantize_x0=quantize_x0, mask=None, x0=z0,
203
+ temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
204
+ x_t=x_T)
205
+ t1 = time.time()
206
+
207
+ if ddim_use_x0_pred:
208
+ sample = intermediates['pred_x0'][-1]
209
+
210
+ x_sample = model.decode_first_stage(sample)
211
+
212
+ try:
213
+ x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
214
+ log["sample_noquant"] = x_sample_noquant
215
+ log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
216
+ except:
217
+ pass
218
+
219
+ log["sample"] = x_sample
220
+ log["time"] = t1 - t0
221
+
222
+ return log
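super_resolution() always runs the diffusion model at its native 4x factor, so the requested target_scale is implemented by pre-resizing the input by target_scale / 4. A tiny worked example of that arithmetic (the helper name here is only for illustration):

    def ldsr_output_size(width, height, target_scale, model_scale=4):
        rate = target_scale / model_scale
        pre_w, pre_h = int(width * rate), int(height * rate)   # size fed to the model
        return pre_w * model_scale, pre_h * model_scale        # size coming out of the model

    print(ldsr_output_size(512, 512, target_scale=2))  # (1024, 1024): 512 -> 256 -> 1024, i.e. 2x overall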
modules/lowvram.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch
2
+ from modules.devices import get_optimal_device
3
+
4
+ module_in_gpu = None
5
+ cpu = torch.device("cpu")
6
+ device = gpu = get_optimal_device()
7
+
8
+
9
+ def send_everything_to_cpu():
10
+ global module_in_gpu
11
+
12
+ if module_in_gpu is not None:
13
+ module_in_gpu.to(cpu)
14
+
15
+ module_in_gpu = None
16
+
17
+
18
+ def setup_for_low_vram(sd_model, use_medvram):
19
+ parents = {}
20
+
21
+ def send_me_to_gpu(module, _):
22
+ """send this module to GPU; send whatever tracked module was previously on the GPU back to the CPU;
23
+ we add this as a forward_pre_hook to a lot of modules, and this way all but one of them will
24
+ stay on the CPU
25
+ """
26
+ global module_in_gpu
27
+
28
+ module = parents.get(module, module)
29
+
30
+ if module_in_gpu == module:
31
+ return
32
+
33
+ if module_in_gpu is not None:
34
+ module_in_gpu.to(cpu)
35
+
36
+ module.to(gpu)
37
+ module_in_gpu = module
38
+
39
+ # see below for register_forward_pre_hook;
40
+ # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
41
+ # useless here, and we just replace those methods
42
+ def first_stage_model_encode_wrap(self, encoder, x):
43
+ send_me_to_gpu(self, None)
44
+ return encoder(x)
45
+
46
+ def first_stage_model_decode_wrap(self, decoder, z):
47
+ send_me_to_gpu(self, None)
48
+ return decoder(z)
49
+
50
+ # remove three big modules, cond, first_stage, and unet from the model and then
51
+ # send the model to GPU. Then put the modules back; they will stay on the CPU.
52
+ stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
53
+ sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
54
+ sd_model.to(device)
55
+ sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
56
+
57
+ # register hooks for the first two models
58
+ sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
59
+ sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
60
+ sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
61
+ sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
62
+ parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
63
+
64
+ if use_medvram:
65
+ sd_model.model.register_forward_pre_hook(send_me_to_gpu)
66
+ else:
67
+ diff_model = sd_model.model.diffusion_model
68
+
69
+ # the third remaining model is still too big for 4 GB, so we also do the same for its submodules
70
+ # so that only one of them is in GPU at a time
71
+ stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
72
+ diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
73
+ sd_model.model.to(device)
74
+ diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
75
+
76
+ # install hooks for bits of third model
77
+ diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
78
+ for block in diff_model.input_blocks:
79
+ block.register_forward_pre_hook(send_me_to_gpu)
80
+ diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
81
+ for block in diff_model.output_blocks:
82
+ block.register_forward_pre_hook(send_me_to_gpu)
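The whole low-VRAM scheme rests on torch's register_forward_pre_hook: right before a tracked module runs, the hook parks the previously active module and moves the current one in. A minimal runnable sketch of that mechanism; both "devices" are CPU here so it runs anywhere, whereas the real module uses the GPU and CPU:

    import torch
    import torch.nn as nn

    active, parked = torch.device("cpu"), torch.device("cpu")  # stand-ins for gpu / cpu
    module_in_use = None

    def send_me_first(module, _inputs):
        global module_in_use
        if module_in_use is not None and module_in_use is not module:
            module_in_use.to(parked)   # evict whatever ran last
        module.to(active)              # bring in the module that is about to run
        module_in_use = module

    net = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    for layer in net:
        layer.register_forward_pre_hook(send_me_first)

    net(torch.randn(1, 8))  # the hook fires before each layer's forward()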
modules/masking.py ADDED
@@ -0,0 +1,99 @@
1
+ from PIL import Image, ImageFilter, ImageOps
2
+
3
+
4
+ def get_crop_region(mask, pad=0):
5
+ """finds a rectangular region that contains all masked areas in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
6
+ For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""
7
+
8
+ h, w = mask.shape
9
+
10
+ crop_left = 0
11
+ for i in range(w):
12
+ if not (mask[:, i] == 0).all():
13
+ break
14
+ crop_left += 1
15
+
16
+ crop_right = 0
17
+ for i in reversed(range(w)):
18
+ if not (mask[:, i] == 0).all():
19
+ break
20
+ crop_right += 1
21
+
22
+ crop_top = 0
23
+ for i in range(h):
24
+ if not (mask[i] == 0).all():
25
+ break
26
+ crop_top += 1
27
+
28
+ crop_bottom = 0
29
+ for i in reversed(range(h)):
30
+ if not (mask[i] == 0).all():
31
+ break
32
+ crop_bottom += 1
33
+
34
+ return (
35
+ int(max(crop_left-pad, 0)),
36
+ int(max(crop_top-pad, 0)),
37
+ int(min(w - crop_right + pad, w)),
38
+ int(min(h - crop_bottom + pad, h))
39
+ )
40
+
41
+
42
+ def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
43
+ """expands the crop region returned by get_crop_region() to match the aspect ratio of the image the region will be processed in; returns the expanded region
44
+ for example, if the user drew a mask in a 128x32 region and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
45
+
46
+ x1, y1, x2, y2 = crop_region
47
+
48
+ ratio_crop_region = (x2 - x1) / (y2 - y1)
49
+ ratio_processing = processing_width / processing_height
50
+
51
+ if ratio_crop_region > ratio_processing:
52
+ desired_height = (x2 - x1) / ratio_processing  # keep the width; this is the height needed to match the processing aspect ratio
53
+ desired_height_diff = int(desired_height - (y2-y1))
54
+ y1 -= desired_height_diff//2
55
+ y2 += desired_height_diff - desired_height_diff//2
56
+ if y2 >= image_height:
57
+ diff = y2 - image_height
58
+ y2 -= diff
59
+ y1 -= diff
60
+ if y1 < 0:
61
+ y2 -= y1
62
+ y1 -= y1
63
+ if y2 >= image_height:
64
+ y2 = image_height
65
+ else:
66
+ desired_width = (y2 - y1) * ratio_processing
67
+ desired_width_diff = int(desired_width - (x2-x1))
68
+ x1 -= desired_width_diff//2
69
+ x2 += desired_width_diff - desired_width_diff//2
70
+ if x2 >= image_width:
71
+ diff = x2 - image_width
72
+ x2 -= diff
73
+ x1 -= diff
74
+ if x1 < 0:
75
+ x2 -= x1
76
+ x1 -= x1
77
+ if x2 >= image_width:
78
+ x2 = image_width
79
+
80
+ return x1, y1, x2, y2
81
+
82
+
83
+ def fill(image, mask):
84
+ """fills masked regions with colors from image using blur. Not extremely effective."""
85
+
86
+ image_mod = Image.new('RGBA', (image.width, image.height))
87
+
88
+ image_masked = Image.new('RGBa', (image.width, image.height))
89
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
90
+
91
+ image_masked = image_masked.convert('RGBa')
92
+
93
+ for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
94
+ blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
95
+ for _ in range(repeats):
96
+ image_mod.alpha_composite(blurred)
97
+
98
+ return image_mod.convert("RGB")
99
+
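A small numeric check of the two helpers above on a synthetic mask (assumes get_crop_region and expand_crop_region from this module are in scope):

    import numpy as np

    mask = np.zeros((512, 512), dtype=np.uint8)
    mask[0:40, 300:340] = 255  # the user painted a small square near the top edge

    region = get_crop_region(mask, pad=16)
    print(region)                                           # (284, 0, 356, 56)
    print(expand_crop_region(region, 512, 512, 512, 512))   # (284, 0, 356, 72): grown to match the 1:1 processing ratio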
modules/memmon.py ADDED
@@ -0,0 +1,85 @@
1
+ import threading
2
+ import time
3
+ from collections import defaultdict
4
+
5
+ import torch
6
+
7
+
8
+ class MemUsageMonitor(threading.Thread):
9
+ run_flag = None
10
+ device = None
11
+ disabled = False
12
+ opts = None
13
+ data = None
14
+
15
+ def __init__(self, name, device, opts):
16
+ threading.Thread.__init__(self)
17
+ self.name = name
18
+ self.device = device
19
+ self.opts = opts
20
+
21
+ self.daemon = True
22
+ self.run_flag = threading.Event()
23
+ self.data = defaultdict(int)
24
+
25
+ try:
26
+ torch.cuda.mem_get_info()
27
+ torch.cuda.memory_stats(self.device)
28
+ except Exception as e: # AMD or whatever
29
+ print(f"Warning: caught exception '{e}', memory monitor disabled")
30
+ self.disabled = True
31
+
32
+ def run(self):
33
+ if self.disabled:
34
+ return
35
+
36
+ while True:
37
+ self.run_flag.wait()
38
+
39
+ torch.cuda.reset_peak_memory_stats()
40
+ self.data.clear()
41
+
42
+ if self.opts.memmon_poll_rate <= 0:
43
+ self.run_flag.clear()
44
+ continue
45
+
46
+ self.data["min_free"] = torch.cuda.mem_get_info()[0]
47
+
48
+ while self.run_flag.is_set():
49
+ free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug?
50
+ self.data["min_free"] = min(self.data["min_free"], free)
51
+
52
+ time.sleep(1 / self.opts.memmon_poll_rate)
53
+
54
+ def dump_debug(self):
55
+ print(self, 'recorded data:')
56
+ for k, v in self.read().items():
57
+ print(k, -(v // -(1024 ** 2)))
58
+
59
+ print(self, 'raw torch memory stats:')
60
+ tm = torch.cuda.memory_stats(self.device)
61
+ for k, v in tm.items():
62
+ if 'bytes' not in k:
63
+ continue
64
+ print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
65
+
66
+ print(torch.cuda.memory_summary())
67
+
68
+ def monitor(self):
69
+ self.run_flag.set()
70
+
71
+ def read(self):
72
+ if not self.disabled:
73
+ free, total = torch.cuda.mem_get_info()
74
+ self.data["total"] = total
75
+
76
+ torch_stats = torch.cuda.memory_stats(self.device)
77
+ self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
78
+ self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
79
+ self.data["system_peak"] = total - self.data["min_free"]
80
+
81
+ return self.data
82
+
83
+ def stop(self):
84
+ self.run_flag.clear()
85
+ return self.read()
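A hedged usage sketch of MemUsageMonitor; it requires a CUDA device, and `opts` below is a stand-in namespace rather than the webui options object:

    from types import SimpleNamespace
    import torch

    opts = SimpleNamespace(memmon_poll_rate=8)          # polls per second
    monitor = MemUsageMonitor("mem", torch.device("cuda"), opts)
    monitor.start()

    monitor.monitor()        # begin polling before a job
    # ... run some CUDA work here ...
    stats = monitor.stop()   # stops polling and returns the collected figures, in bytes
    print(stats.get("active_peak"), stats.get("system_peak"))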
modules/modelloader.py ADDED
@@ -0,0 +1,153 @@
1
+ import glob
2
+ import os
3
+ import shutil
4
+ import importlib
5
+ from urllib.parse import urlparse
6
+
7
+ from basicsr.utils.download_util import load_file_from_url
8
+ from modules import shared
9
+ from modules.upscaler import Upscaler
10
+ from modules.paths import script_path, models_path
11
+
12
+
13
+ def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
14
+ """
15
+ A one-and-done loader that tries to find the desired models in the specified directories.
16
+
17
+ @param download_name: Specify to download from model_url immediately.
18
+ @param model_url: If no other models are found, this will be downloaded on upscale.
19
+ @param model_path: The location to store/find models in.
20
+ @param command_path: A command-line argument to search for models in first.
21
+ @param ext_filter: An optional list of filename extensions to filter by
22
+ @return: A list of paths containing the desired model(s)
23
+ """
24
+ output = []
25
+
26
+ if ext_filter is None:
27
+ ext_filter = []
28
+
29
+ try:
30
+ places = []
31
+
32
+ if command_path is not None and command_path != model_path:
33
+ pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
34
+ if os.path.exists(pretrained_path):
35
+ print(f"Appending path: {pretrained_path}")
36
+ places.append(pretrained_path)
37
+ elif os.path.exists(command_path):
38
+ places.append(command_path)
39
+
40
+ places.append(model_path)
41
+
42
+ for place in places:
43
+ if os.path.exists(place):
44
+ for file in glob.iglob(place + '**/**', recursive=True):
45
+ full_path = file
46
+ if os.path.isdir(full_path):
47
+ continue
48
+ if len(ext_filter) != 0:
49
+ model_name, extension = os.path.splitext(file)
50
+ if extension not in ext_filter:
51
+ continue
52
+ if file not in output:
53
+ output.append(full_path)
54
+
55
+ if model_url is not None and len(output) == 0:
56
+ if download_name is not None:
57
+ dl = load_file_from_url(model_url, model_path, True, download_name)
58
+ output.append(dl)
59
+ else:
60
+ output.append(model_url)
61
+
62
+ except Exception:
63
+ pass
64
+
65
+ return output
66
+
67
+
68
+ def friendly_name(file: str):
69
+ if "http" in file:
70
+ file = urlparse(file).path
71
+
72
+ file = os.path.basename(file)
73
+ model_name, extension = os.path.splitext(file)
74
+ return model_name
75
+
76
+
77
+ def cleanup_models():
78
+ # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
79
+ # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
80
+ # somehow auto-register and just do these things...
81
+ root_path = script_path
82
+ src_path = models_path
83
+ dest_path = os.path.join(models_path, "Stable-diffusion")
84
+ move_files(src_path, dest_path, ".ckpt")
85
+ src_path = os.path.join(root_path, "ESRGAN")
86
+ dest_path = os.path.join(models_path, "ESRGAN")
87
+ move_files(src_path, dest_path)
88
+ src_path = os.path.join(root_path, "gfpgan")
89
+ dest_path = os.path.join(models_path, "GFPGAN")
90
+ move_files(src_path, dest_path)
91
+ src_path = os.path.join(root_path, "SwinIR")
92
+ dest_path = os.path.join(models_path, "SwinIR")
93
+ move_files(src_path, dest_path)
94
+ src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
95
+ dest_path = os.path.join(models_path, "LDSR")
96
+ move_files(src_path, dest_path)
97
+
98
+
99
+ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
100
+ try:
101
+ if not os.path.exists(dest_path):
102
+ os.makedirs(dest_path)
103
+ if os.path.exists(src_path):
104
+ for file in os.listdir(src_path):
105
+ fullpath = os.path.join(src_path, file)
106
+ if os.path.isfile(fullpath):
107
+ if ext_filter is not None:
108
+ if ext_filter not in file:
109
+ continue
110
+ print(f"Moving {file} from {src_path} to {dest_path}.")
111
+ try:
112
+ shutil.move(fullpath, dest_path)
113
+ except:
114
+ pass
115
+ if len(os.listdir(src_path)) == 0:
116
+ print(f"Removing empty folder: {src_path}")
117
+ shutil.rmtree(src_path, True)
118
+ except:
119
+ pass
120
+
121
+
122
+ def load_upscalers():
123
+ sd = shared.script_path
124
+ # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
125
+ # so we'll try to import any _model.py files before looking in __subclasses__
126
+ modules_dir = os.path.join(sd, "modules")
127
+ for file in os.listdir(modules_dir):
128
+ if "_model.py" in file:
129
+ model_name = file.replace("_model.py", "")
130
+ full_model = f"modules.{model_name}_model"
131
+ try:
132
+ importlib.import_module(full_model)
133
+ except:
134
+ pass
135
+ datas = []
136
+ c_o = vars(shared.cmd_opts)
137
+ for cls in Upscaler.__subclasses__():
138
+ name = cls.__name__
139
+ module_name = cls.__module__
140
+ module = importlib.import_module(module_name)
141
+ class_ = getattr(module, name)
142
+ cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
143
+ opt_string = None
144
+ try:
145
+ if cmd_name in c_o:
146
+ opt_string = c_o[cmd_name]
147
+ except:
148
+ pass
149
+ scaler = class_(opt_string)
150
+ for child in scaler.scalers:
151
+ datas.append(child)
152
+
153
+ shared.sd_upscalers = datas
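For reference, a hedged example of calling load_models() directly; the directory and extensions below are placeholders for illustration, not values defined elsewhere in this commit:

    model_paths = load_models(
        model_path="models/ESRGAN",     # hypothetical local directory
        model_url=None,
        command_path=None,
        ext_filter=[".pt", ".pth"],
    )
    print(model_paths)  # e.g. ['models/ESRGAN/4x_foo.pth'] if any weights are present, else []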
modules/paths.py ADDED
@@ -0,0 +1,38 @@
1
+ import argparse
2
+ import os
3
+ import sys
4
+
5
+ script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
6
+ models_path = os.path.join(script_path, "models")
7
+ sys.path.insert(0, script_path)
8
+
9
+ # search for directory of stable diffusion in following places
10
+ sd_path = None
11
+ possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
12
+ for possible_sd_path in possible_sd_paths:
13
+ if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
14
+ sd_path = os.path.abspath(possible_sd_path)
15
+
16
+ assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
17
+
18
+ path_dirs = [
19
+ (sd_path, 'ldm', 'Stable Diffusion', []),
20
+ (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
21
+ (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
22
+ (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
23
+ (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
24
+ ]
25
+
26
+ paths = {}
27
+
28
+ for d, must_exist, what, options in path_dirs:
29
+ must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
30
+ if not os.path.exists(must_exist_path):
31
+ print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
32
+ else:
33
+ d = os.path.abspath(d)
34
+ if "atstart" in options:
35
+ sys.path.insert(0, d)
36
+ else:
37
+ sys.path.append(d)
38
+ paths[what] = d
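Other modules in this commit consume the registry built above by its friendly name, for example the BLIP loader in interrogate.py. A short sketch of that lookup:

    from modules import paths

    blip_dir = paths.paths.get("BLIP")
    if blip_dir is None:
        print("BLIP repository not found; a warning was already printed at import time")
    else:
        print("BLIP repository registered at", blip_dir)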
modules/processing.py ADDED
@@ -0,0 +1,688 @@
1
+ import json
2
+ import math
3
+ import os
4
+ import sys
5
+
6
+ import torch
7
+ import numpy as np
8
+ from PIL import Image, ImageFilter, ImageOps
9
+ import random
10
+ import cv2
11
+ from skimage import exposure
12
+
13
+ import modules.sd_hijack
14
+ from modules import devices, prompt_parser, masking, sd_samplers, lowvram
15
+ from modules.sd_hijack import model_hijack
16
+ from modules.shared import opts, cmd_opts, state
17
+ import modules.shared as shared
18
+ import modules.face_restoration
19
+ import modules.images as images
20
+ import modules.styles
21
+ import logging
22
+
23
+
24
+ # some of those options should not be changed at all because they would break the model, so I removed them from options.
25
+ opt_C = 4
26
+ opt_f = 8
27
+
28
+
29
+ def setup_color_correction(image):
30
+ logging.info("Calibrating color correction.")
31
+ correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
32
+ return correction_target
33
+
34
+
35
+ def apply_color_correction(correction, image):
36
+ logging.info("Applying color correction.")
37
+ image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
38
+ cv2.cvtColor(
39
+ np.asarray(image),
40
+ cv2.COLOR_RGB2LAB
41
+ ),
42
+ correction,
43
+ channel_axis=2
44
+ ), cv2.COLOR_LAB2RGB).astype("uint8"))
45
+
46
+ return image
47
+
48
+
49
+ class StableDiffusionProcessing:
50
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
51
+ self.sd_model = sd_model
52
+ self.outpath_samples: str = outpath_samples
53
+ self.outpath_grids: str = outpath_grids
54
+ self.prompt: str = prompt
55
+ self.prompt_for_display: str = None
56
+ self.negative_prompt: str = (negative_prompt or "")
57
+ self.styles: list = styles or []
58
+ self.seed: int = seed
59
+ self.subseed: int = subseed
60
+ self.subseed_strength: float = subseed_strength
61
+ self.seed_resize_from_h: int = seed_resize_from_h
62
+ self.seed_resize_from_w: int = seed_resize_from_w
63
+ self.sampler_index: int = sampler_index
64
+ self.batch_size: int = batch_size
65
+ self.n_iter: int = n_iter
66
+ self.steps: int = steps
67
+ self.cfg_scale: float = cfg_scale
68
+ self.width: int = width
69
+ self.height: int = height
70
+ self.restore_faces: bool = restore_faces
71
+ self.tiling: bool = tiling
72
+ self.do_not_save_samples: bool = do_not_save_samples
73
+ self.do_not_save_grid: bool = do_not_save_grid
74
+ self.extra_generation_params: dict = extra_generation_params or {}
75
+ self.overlay_images = overlay_images
76
+ self.eta = eta
77
+ self.paste_to = None
78
+ self.color_corrections = None
79
+ self.denoising_strength: float = 0
80
+ self.sampler_noise_scheduler_override = None
81
+ self.ddim_discretize = opts.ddim_discretize
82
+ self.s_churn = opts.s_churn
83
+ self.s_tmin = opts.s_tmin
84
+ self.s_tmax = float('inf') # not representable as a standard ui option
85
+ self.s_noise = opts.s_noise
86
+
87
+ if not seed_enable_extras:
88
+ self.subseed = -1
89
+ self.subseed_strength = 0
90
+ self.seed_resize_from_h = 0
91
+ self.seed_resize_from_w = 0
92
+
93
+ def init(self, all_prompts, all_seeds, all_subseeds):
94
+ pass
95
+
96
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
97
+ raise NotImplementedError()
98
+
99
+
100
+ class Processed:
101
+ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
102
+ self.images = images_list
103
+ self.prompt = p.prompt
104
+ self.negative_prompt = p.negative_prompt
105
+ self.seed = seed
106
+ self.subseed = subseed
107
+ self.subseed_strength = p.subseed_strength
108
+ self.info = info
109
+ self.width = p.width
110
+ self.height = p.height
111
+ self.sampler_index = p.sampler_index
112
+ self.sampler = sd_samplers.samplers[p.sampler_index].name
113
+ self.cfg_scale = p.cfg_scale
114
+ self.steps = p.steps
115
+ self.batch_size = p.batch_size
116
+ self.restore_faces = p.restore_faces
117
+ self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
118
+ self.sd_model_hash = shared.sd_model.sd_model_hash
119
+ self.seed_resize_from_w = p.seed_resize_from_w
120
+ self.seed_resize_from_h = p.seed_resize_from_h
121
+ self.denoising_strength = getattr(p, 'denoising_strength', None)
122
+ self.extra_generation_params = p.extra_generation_params
123
+ self.index_of_first_image = index_of_first_image
124
+ self.styles = p.styles
125
+ self.job_timestamp = state.job_timestamp
126
+
127
+ self.eta = p.eta
128
+ self.ddim_discretize = p.ddim_discretize
129
+ self.s_churn = p.s_churn
130
+ self.s_tmin = p.s_tmin
131
+ self.s_tmax = p.s_tmax
132
+ self.s_noise = p.s_noise
133
+ self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
134
+ self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
135
+ self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
136
+ self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
137
+ self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
138
+
139
+ self.all_prompts = all_prompts or [self.prompt]
140
+ self.all_seeds = all_seeds or [self.seed]
141
+ self.all_subseeds = all_subseeds or [self.subseed]
142
+ self.infotexts = infotexts or [info]
143
+
144
+ def js(self):
145
+ obj = {
146
+ "prompt": self.prompt,
147
+ "all_prompts": self.all_prompts,
148
+ "negative_prompt": self.negative_prompt,
149
+ "seed": self.seed,
150
+ "all_seeds": self.all_seeds,
151
+ "subseed": self.subseed,
152
+ "all_subseeds": self.all_subseeds,
153
+ "subseed_strength": self.subseed_strength,
154
+ "width": self.width,
155
+ "height": self.height,
156
+ "sampler_index": self.sampler_index,
157
+ "sampler": self.sampler,
158
+ "cfg_scale": self.cfg_scale,
159
+ "steps": self.steps,
160
+ "batch_size": self.batch_size,
161
+ "restore_faces": self.restore_faces,
162
+ "face_restoration_model": self.face_restoration_model,
163
+ "sd_model_hash": self.sd_model_hash,
164
+ "seed_resize_from_w": self.seed_resize_from_w,
165
+ "seed_resize_from_h": self.seed_resize_from_h,
166
+ "denoising_strength": self.denoising_strength,
167
+ "extra_generation_params": self.extra_generation_params,
168
+ "index_of_first_image": self.index_of_first_image,
169
+ "infotexts": self.infotexts,
170
+ "styles": self.styles,
171
+ "job_timestamp": self.job_timestamp,
172
+ }
173
+
174
+ return json.dumps(obj)
175
+
176
+ def infotext(self, p: StableDiffusionProcessing, index):
177
+ return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
178
+
179
+
180
+ # from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
181
+ def slerp(val, low, high):
182
+ low_norm = low/torch.norm(low, dim=1, keepdim=True)
183
+ high_norm = high/torch.norm(high, dim=1, keepdim=True)
184
+ dot = (low_norm*high_norm).sum(1)
185
+
186
+ if dot.mean() > 0.9995:
187
+ return low * val + high * (1 - val)
188
+
189
+ omega = torch.acos(dot)
190
+ so = torch.sin(omega)
191
+ res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
192
+ return res
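A quick numeric check of slerp(), which is how a variation (sub)seed's noise is blended with the main seed's noise via subseed_strength; it assumes slerp from this module is in scope, and with two orthogonal unit rows the result stays on the arc between them:

    import torch

    low = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
    high = torch.tensor([[0.0, 1.0, 0.0, 0.0]])
    for val in (0.0, 0.5, 1.0):
        print(val, slerp(val, low, high))
    # 0.0 -> low, 1.0 -> high, 0.5 -> [0.7071, 0.7071, 0, 0]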
193
+
194
+
195
+ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
196
+ xs = []
197
+
198
+ # if we have multiple seeds, this means we are working with batch size>1; this then
199
+ # enables the generation of additional tensors with noise that the sampler will use during its processing.
200
+ # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
201
+ # produce the same images as with two batches [100], [101].
202
+ if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
203
+ sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
204
+ else:
205
+ sampler_noises = None
206
+
207
+ for i, seed in enumerate(seeds):
208
+ noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
209
+
210
+ subnoise = None
211
+ if subseeds is not None:
212
+ subseed = 0 if i >= len(subseeds) else subseeds[i]
213
+
214
+ subnoise = devices.randn(subseed, noise_shape)
215
+
216
+ # randn results depend on device; gpu and cpu get different results for same seed;
217
+ # the way I see it, it's better to do this on CPU, so that everyone gets same result;
218
+ # but the original script had it like this, so I do not dare change it for now because
219
+ # it will break everyone's seeds.
220
+ noise = devices.randn(seed, noise_shape)
221
+
222
+ if subnoise is not None:
223
+ noise = slerp(subseed_strength, noise, subnoise)
224
+
225
+ if noise_shape != shape:
226
+ x = devices.randn(seed, shape)
227
+ dx = (shape[2] - noise_shape[2]) // 2
228
+ dy = (shape[1] - noise_shape[1]) // 2
229
+ w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
230
+ h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
231
+ tx = 0 if dx < 0 else dx
232
+ ty = 0 if dy < 0 else dy
233
+ dx = max(-dx, 0)
234
+ dy = max(-dy, 0)
235
+
236
+ x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
237
+ noise = x
238
+
239
+ if sampler_noises is not None:
240
+ cnt = p.sampler.number_of_needed_noises(p)
241
+
242
+ for j in range(cnt):
243
+ sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
244
+
245
+ xs.append(noise)
246
+
247
+ if sampler_noises is not None:
248
+ p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
249
+
250
+ x = torch.stack(xs).to(shared.device)
251
+ return x
252
+
253
+
254
+ def get_fixed_seed(seed):
255
+ if seed is None or seed == '' or seed == -1:
256
+ return int(random.randrange(4294967294))
257
+
258
+ return seed
259
+
260
+
261
+ def fix_seed(p):
262
+ p.seed = get_fixed_seed(p.seed)
263
+ p.subseed = get_fixed_seed(p.subseed)
264
+
265
+
266
+ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
267
+ index = position_in_batch + iteration * p.batch_size
268
+
269
+ generation_params = {
270
+ "Steps": p.steps,
271
+ "Sampler": sd_samplers.samplers[p.sampler_index].name,
272
+ "CFG scale": p.cfg_scale,
273
+ "Seed": all_seeds[index],
274
+ "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
275
+ "Size": f"{p.width}x{p.height}",
276
+ "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
277
+ "Batch size": (None if p.batch_size < 2 else p.batch_size),
278
+ "Batch pos": (None if p.batch_size < 2 else position_in_batch),
279
+ "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
280
+ "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
281
+ "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
282
+ "Denoising strength": getattr(p, 'denoising_strength', None),
283
+ "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
284
+ }
285
+
286
+ generation_params.update(p.extra_generation_params)
287
+
288
+ generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
289
+
290
+ negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
291
+
292
+ return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
293
+
294
+
295
+ def process_images(p: StableDiffusionProcessing) -> Processed:
296
+ """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
297
+
298
+ if type(p.prompt) == list:
299
+ assert(len(p.prompt) > 0)
300
+ else:
301
+ assert p.prompt is not None
302
+
303
+ devices.torch_gc()
304
+
305
+ seed = get_fixed_seed(p.seed)
306
+ subseed = get_fixed_seed(p.subseed)
307
+
308
+ if p.outpath_samples is not None:
309
+ os.makedirs(p.outpath_samples, exist_ok=True)
310
+
311
+ if p.outpath_grids is not None:
312
+ os.makedirs(p.outpath_grids, exist_ok=True)
313
+
314
+ modules.sd_hijack.model_hijack.apply_circular(p.tiling)
315
+
316
+ comments = {}
317
+
318
+ shared.prompt_styles.apply_styles(p)
319
+
320
+ if type(p.prompt) == list:
321
+ all_prompts = p.prompt
322
+ else:
323
+ all_prompts = p.batch_size * p.n_iter * [p.prompt]
324
+
325
+ if type(seed) == list:
326
+ all_seeds = seed
327
+ else:
328
+ all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
329
+
330
+ if type(subseed) == list:
331
+ all_subseeds = subseed
332
+ else:
333
+ all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
334
+
335
+ def infotext(iteration=0, position_in_batch=0):
336
+ return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
337
+
338
+ if os.path.exists(cmd_opts.embeddings_dir):
339
+ model_hijack.embedding_db.load_textual_inversion_embeddings()
340
+
341
+ infotexts = []
342
+ output_images = []
343
+
344
+ with torch.no_grad():
345
+ with devices.autocast():
346
+ p.init(all_prompts, all_seeds, all_subseeds)
347
+
348
+ if state.job_count == -1:
349
+ state.job_count = p.n_iter
350
+
351
+ for n in range(p.n_iter):
352
+ if state.interrupted:
353
+ break
354
+
355
+ prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
356
+ seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
357
+ subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
358
+
359
+ if (len(prompts) == 0):
360
+ break
361
+
362
+ #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
363
+ #c = p.sd_model.get_learned_conditioning(prompts)
364
+ with devices.autocast():
365
+ uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
366
+ c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
367
+
368
+ if len(model_hijack.comments) > 0:
369
+ for comment in model_hijack.comments:
370
+ comments[comment] = 1
371
+
372
+ if p.n_iter > 1:
373
+ shared.state.job = f"Batch {n+1} out of {p.n_iter}"
374
+
375
+ with devices.autocast():
376
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
377
+
378
+ if state.interrupted:
379
+
380
+ # if we are interrupted, sample returns just noise
381
+ # use the image collected previously in sampler loop
382
+ samples_ddim = shared.state.current_latent
383
+
384
+ samples_ddim = samples_ddim.to(devices.dtype)
385
+
386
+ x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
387
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
388
+
389
+ del samples_ddim
390
+
391
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
392
+ lowvram.send_everything_to_cpu()
393
+
394
+ devices.torch_gc()
395
+
396
+ if opts.filter_nsfw:
397
+ import modules.safety as safety
398
+ x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
399
+
400
+ for i, x_sample in enumerate(x_samples_ddim):
401
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
402
+ x_sample = x_sample.astype(np.uint8)
403
+
404
+ if p.restore_faces:
405
+ if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
406
+ images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
407
+
408
+ devices.torch_gc()
409
+
410
+ x_sample = modules.face_restoration.restore_faces(x_sample)
411
+ devices.torch_gc()
412
+
413
+ image = Image.fromarray(x_sample)
414
+
415
+ if p.color_corrections is not None and i < len(p.color_corrections):
416
+ if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
417
+ images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
418
+ image = apply_color_correction(p.color_corrections[i], image)
419
+
420
+ if p.overlay_images is not None and i < len(p.overlay_images):
421
+ overlay = p.overlay_images[i]
422
+
423
+ if p.paste_to is not None:
424
+ x, y, w, h = p.paste_to
425
+ base_image = Image.new('RGBA', (overlay.width, overlay.height))
426
+ image = images.resize_image(1, image, w, h)
427
+ base_image.paste(image, (x, y))
428
+ image = base_image
429
+
430
+ image = image.convert('RGBA')
431
+ image.alpha_composite(overlay)
432
+ image = image.convert('RGB')
433
+
434
+ if opts.samples_save and not p.do_not_save_samples:
435
+ images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
436
+
437
+ text = infotext(n, i)
438
+ infotexts.append(text)
439
+ image.info["parameters"] = text
440
+ output_images.append(image)
441
+
442
+ del x_samples_ddim
443
+
444
+ devices.torch_gc()
445
+
446
+ state.nextjob()
447
+
448
+ p.color_corrections = None
449
+
450
+ index_of_first_image = 0
451
+ unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
452
+ if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
453
+ grid = images.image_grid(output_images, p.batch_size)
454
+
455
+ if opts.return_grid:
456
+ text = infotext()
457
+ infotexts.insert(0, text)
458
+ grid.info["parameters"] = text
459
+ output_images.insert(0, grid)
460
+ index_of_first_image = 1
461
+
462
+ if opts.grid_save:
463
+ images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
464
+
465
+ devices.torch_gc()
466
+ return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
467
+
468
+
469
+ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
470
+ sampler = None
471
+ firstphase_width = 0
472
+ firstphase_height = 0
473
+ firstphase_width_truncated = 0
474
+ firstphase_height_truncated = 0
475
+
476
+ def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs):
477
+ super().__init__(**kwargs)
478
+ self.enable_hr = enable_hr
479
+ self.scale_latent = scale_latent
480
+ self.denoising_strength = denoising_strength
481
+
482
+ def init(self, all_prompts, all_seeds, all_subseeds):
483
+ if self.enable_hr:
484
+ if state.job_count == -1:
485
+ state.job_count = self.n_iter * 2
486
+ else:
487
+ state.job_count = state.job_count * 2
488
+
489
+ desired_pixel_count = 512 * 512
490
+ actual_pixel_count = self.width * self.height
491
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
492
+
493
+ self.firstphase_width = math.ceil(scale * self.width / 64) * 64
494
+ self.firstphase_height = math.ceil(scale * self.height / 64) * 64
495
+ self.firstphase_width_truncated = int(scale * self.width)
496
+ self.firstphase_height_truncated = int(scale * self.height)
497
+
498
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
499
+ self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
500
+
501
+ if not self.enable_hr:
502
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
503
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
504
+ return samples
505
+
506
+ x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
507
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
508
+
509
+ truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
510
+ truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
511
+
512
+ samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]
513
+
514
+ if self.scale_latent:
515
+ samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
516
+ else:
517
+ decoded_samples = self.sd_model.decode_first_stage(samples)
518
+
519
+ if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
520
+ decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
521
+ else:
522
+ lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
523
+
524
+ batch_images = []
525
+ for i, x_sample in enumerate(lowres_samples):
526
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
527
+ x_sample = x_sample.astype(np.uint8)
528
+ image = Image.fromarray(x_sample)
529
+ image = images.resize_image(0, image, self.width, self.height)
530
+ image = np.array(image).astype(np.float32) / 255.0
531
+ image = np.moveaxis(image, 2, 0)
532
+ batch_images.append(image)
533
+
534
+ decoded_samples = torch.from_numpy(np.array(batch_images))
535
+ decoded_samples = decoded_samples.to(shared.device)
536
+ decoded_samples = 2. * decoded_samples - 1.
537
+
538
+ samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
539
+
540
+ shared.state.nextjob()
541
+
542
+ self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
543
+
544
+ noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
545
+
546
+ # GC now before running the next img2img to prevent running out of memory
547
+ x = None
548
+ devices.torch_gc()
549
+
550
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
551
+
552
+ return samples
553
+
554
+
555
+ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
556
+ sampler = None
557
+
558
+ def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
559
+ super().__init__(**kwargs)
560
+
561
+ self.init_images = init_images
562
+ self.resize_mode: int = resize_mode
563
+ self.denoising_strength: float = denoising_strength
564
+ self.init_latent = None
565
+ self.image_mask = mask
566
+ #self.image_unblurred_mask = None
567
+ self.latent_mask = None
568
+ self.mask_for_overlay = None
569
+ self.mask_blur = mask_blur
570
+ self.inpainting_fill = inpainting_fill
571
+ self.inpaint_full_res = inpaint_full_res
572
+ self.inpaint_full_res_padding = inpaint_full_res_padding
573
+ self.inpainting_mask_invert = inpainting_mask_invert
574
+ self.mask = None
575
+ self.nmask = None
576
+
577
+ def init(self, all_prompts, all_seeds, all_subseeds):
578
+ self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
579
+ crop_region = None
580
+
581
+ if self.image_mask is not None:
582
+ self.image_mask = self.image_mask.convert('L')
583
+
584
+ if self.inpainting_mask_invert:
585
+ self.image_mask = ImageOps.invert(self.image_mask)
586
+
587
+ #self.image_unblurred_mask = self.image_mask
588
+
589
+ if self.mask_blur > 0:
590
+ self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
591
+
592
+ if self.inpaint_full_res:
593
+ self.mask_for_overlay = self.image_mask
594
+ mask = self.image_mask.convert('L')
595
+ crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
596
+ crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
597
+ x1, y1, x2, y2 = crop_region
598
+
599
+ mask = mask.crop(crop_region)
600
+ self.image_mask = images.resize_image(2, mask, self.width, self.height)
601
+ self.paste_to = (x1, y1, x2-x1, y2-y1)
602
+ else:
603
+ self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
604
+ np_mask = np.array(self.image_mask)
605
+ np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
606
+ self.mask_for_overlay = Image.fromarray(np_mask)
607
+
608
+ self.overlay_images = []
609
+
610
+ latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
611
+
612
+ add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
613
+ if add_color_corrections:
614
+ self.color_corrections = []
615
+ imgs = []
616
+ for img in self.init_images:
617
+ image = img.convert("RGB")
618
+
619
+ if crop_region is None:
620
+ image = images.resize_image(self.resize_mode, image, self.width, self.height)
621
+
622
+ if self.image_mask is not None:
623
+ image_masked = Image.new('RGBa', (image.width, image.height))
624
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
625
+
626
+ self.overlay_images.append(image_masked.convert('RGBA'))
627
+
628
+ if crop_region is not None:
629
+ image = image.crop(crop_region)
630
+ image = images.resize_image(2, image, self.width, self.height)
631
+
632
+ if self.image_mask is not None:
633
+ if self.inpainting_fill != 1:
634
+ image = masking.fill(image, latent_mask)
635
+
636
+ if add_color_corrections:
637
+ self.color_corrections.append(setup_color_correction(image))
638
+
639
+ image = np.array(image).astype(np.float32) / 255.0
640
+ image = np.moveaxis(image, 2, 0)
641
+
642
+ imgs.append(image)
643
+
644
+ if len(imgs) == 1:
645
+ batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
646
+ if self.overlay_images is not None:
647
+ self.overlay_images = self.overlay_images * self.batch_size
648
+ elif len(imgs) <= self.batch_size:
649
+ self.batch_size = len(imgs)
650
+ batch_images = np.array(imgs)
651
+ else:
652
+ raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
653
+
654
+ image = torch.from_numpy(batch_images)
655
+ image = 2. * image - 1.
656
+ image = image.to(shared.device)
657
+
658
+ self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
659
+
660
+ if self.image_mask is not None:
661
+ init_mask = latent_mask
662
+ latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
663
+ latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
664
+ latmask = latmask[0]
665
+ latmask = np.around(latmask)
666
+ latmask = np.tile(latmask[None], (4, 1, 1))
667
+
668
+ self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
669
+ self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
670
+
671
+ # this needs to be fixed to be done in sample() using actual seeds for batches
672
+ if self.inpainting_fill == 2:
673
+ self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
674
+ elif self.inpainting_fill == 3:
675
+ self.init_latent = self.init_latent * self.mask
676
+
677
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
678
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
679
+
680
+ samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
681
+
682
+ if self.mask is not None:
683
+ samples = samples * self.nmask + self.init_latent * self.mask
684
+
685
+ del x
686
+ devices.torch_gc()
687
+
688
+ return samples
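Note on the seed-resize branch of create_random_tensors above: noise generated at the old latent size is centred inside (or cropped into) the tensor for the new latent shape, so resized images keep roughly the same composition for the same seed. A minimal standalone sketch of that placement arithmetic, using plain torch.randn in place of the seeded devices.randn helper (illustrative only, not part of the uploaded files):

import torch

def place_noise(noise, shape):
    # noise: (C, H_old, W_old); shape: (C, H_new, W_new) target latent shape
    x = torch.randn(shape)
    dx = (shape[2] - noise.shape[2]) // 2
    dy = (shape[1] - noise.shape[1]) // 2
    w = noise.shape[2] if dx >= 0 else noise.shape[2] + 2 * dx
    h = noise.shape[1] if dy >= 0 else noise.shape[1] + 2 * dy
    tx, ty = max(dx, 0), max(dy, 0)
    dx, dy = max(-dx, 0), max(-dy, 0)
    # paste the overlapping region of the old-size noise into the centre of the new tensor
    x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
    return x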
modules/prompt_parser.py ADDED
@@ -0,0 +1,352 @@
1
+ import re
2
+ from collections import namedtuple
3
+ from typing import List
4
+ import lark
5
+
6
+ # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
7
+ # will be represented with prompt_schedule like this (assuming steps=100):
8
+ # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
9
+ # [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
10
+ # [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
11
+ # [75, 'fantasy landscape with a lake and an oak in background masterful']
12
+ # [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
13
+
14
+ schedule_parser = lark.Lark(r"""
15
+ !start: (prompt | /[][():]/+)*
16
+ prompt: (emphasized | scheduled | plain | WHITESPACE)*
17
+ !emphasized: "(" prompt ")"
18
+ | "(" prompt ":" prompt ")"
19
+ | "[" prompt "]"
20
+ scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
21
+ WHITESPACE: /\s+/
22
+ plain: /([^\\\[\]():]|\\.)+/
23
+ %import common.SIGNED_NUMBER -> NUMBER
24
+ """)
25
+
26
+ def get_learned_conditioning_prompt_schedules(prompts, steps):
27
+ """
28
+ >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
29
+ >>> g("test")
30
+ [[10, 'test']]
31
+ >>> g("a [b:3]")
32
+ [[3, 'a '], [10, 'a b']]
33
+ >>> g("a [b: 3]")
34
+ [[3, 'a '], [10, 'a b']]
35
+ >>> g("a [[[b]]:2]")
36
+ [[2, 'a '], [10, 'a [[b]]']]
37
+ >>> g("[(a:2):3]")
38
+ [[3, ''], [10, '(a:2)']]
39
+ >>> g("a [b : c : 1] d")
40
+ [[1, 'a b d'], [10, 'a c d']]
41
+ >>> g("a[b:[c:d:2]:1]e")
42
+ [[1, 'abe'], [2, 'ace'], [10, 'ade']]
43
+ >>> g("a [unbalanced")
44
+ [[10, 'a [unbalanced']]
45
+ >>> g("a [b:.5] c")
46
+ [[5, 'a c'], [10, 'a b c']]
47
+ >>> g("a [{b|d{:.5] c") # not handling this right now
48
+ [[5, 'a c'], [10, 'a {b|d{ c']]
49
+ >>> g("((a][:b:c [d:3]")
50
+ [[3, '((a][:b:c '], [10, '((a][:b:c d']]
51
+ """
52
+
53
+ def collect_steps(steps, tree):
54
+ l = [steps]
55
+ class CollectSteps(lark.Visitor):
56
+ def scheduled(self, tree):
57
+ tree.children[-1] = float(tree.children[-1])
58
+ if tree.children[-1] < 1:
59
+ tree.children[-1] *= steps
60
+ tree.children[-1] = min(steps, int(tree.children[-1]))
61
+ l.append(tree.children[-1])
62
+ CollectSteps().visit(tree)
63
+ return sorted(set(l))
64
+
65
+ def at_step(step, tree):
66
+ class AtStep(lark.Transformer):
67
+ def scheduled(self, args):
68
+ before, after, _, when = args
69
+ yield before or () if step <= when else after
70
+ def start(self, args):
71
+ def flatten(x):
72
+ if type(x) == str:
73
+ yield x
74
+ else:
75
+ for gen in x:
76
+ yield from flatten(gen)
77
+ return ''.join(flatten(args))
78
+ def plain(self, args):
79
+ yield args[0].value
80
+ def __default__(self, data, children, meta):
81
+ for child in children:
82
+ yield from child
83
+ return AtStep().transform(tree)
84
+
85
+ def get_schedule(prompt):
86
+ try:
87
+ tree = schedule_parser.parse(prompt)
88
+ except lark.exceptions.LarkError as e:
89
+ if 0:
90
+ import traceback
91
+ traceback.print_exc()
92
+ return [[steps, prompt]]
93
+ return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
94
+
95
+ promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
96
+ return [promptdict[prompt] for prompt in prompts]
97
+
98
+
99
+ ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
100
+
101
+
102
+ def get_learned_conditioning(model, prompts, steps):
103
+ """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond),
104
+ and the sampling step at which this condition is to be replaced by the next one.
105
+
106
+ Input:
107
+ (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
108
+
109
+ Output:
110
+ [
111
+ [
112
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
113
+ ],
114
+ [
115
+ ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
116
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
117
+ ]
118
+ ]
119
+ """
120
+ res = []
121
+
122
+ prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
123
+ cache = {}
124
+
125
+ for prompt, prompt_schedule in zip(prompts, prompt_schedules):
126
+
127
+ cached = cache.get(prompt, None)
128
+ if cached is not None:
129
+ res.append(cached)
130
+ continue
131
+
132
+ texts = [x[1] for x in prompt_schedule]
133
+ conds = model.get_learned_conditioning(texts)
134
+
135
+ cond_schedule = []
136
+ for i, (end_at_step, text) in enumerate(prompt_schedule):
137
+ cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
138
+
139
+ cache[prompt] = cond_schedule
140
+ res.append(cond_schedule)
141
+
142
+ return res
143
+
144
+
145
+ re_AND = re.compile(r"\bAND\b")
146
+ re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
147
+
148
+ def get_multicond_prompt_list(prompts):
149
+ res_indexes = []
150
+
151
+ prompt_flat_list = []
152
+ prompt_indexes = {}
153
+
154
+ for prompt in prompts:
155
+ subprompts = re_AND.split(prompt)
156
+
157
+ indexes = []
158
+ for subprompt in subprompts:
159
+ match = re_weight.search(subprompt)
160
+
161
+ text, weight = match.groups() if match is not None else (subprompt, 1.0)
162
+
163
+ weight = float(weight) if weight is not None else 1.0
164
+
165
+ index = prompt_indexes.get(text, None)
166
+ if index is None:
167
+ index = len(prompt_flat_list)
168
+ prompt_flat_list.append(text)
169
+ prompt_indexes[text] = index
170
+
171
+ indexes.append((index, weight))
172
+
173
+ res_indexes.append(indexes)
174
+
175
+ return res_indexes, prompt_flat_list, prompt_indexes
176
+
177
+
178
+ class ComposableScheduledPromptConditioning:
179
+ def __init__(self, schedules, weight=1.0):
180
+ self.schedules: List[ScheduledPromptConditioning] = schedules
181
+ self.weight: float = weight
182
+
183
+
184
+ class MulticondLearnedConditioning:
185
+ def __init__(self, shape, batch):
186
+ self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
187
+ self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
188
+
189
+ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
190
+ """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
191
+ For each prompt, the list is obtained by splitting the prompt using the AND separator.
192
+
193
+ https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
194
+ """
195
+
196
+ res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
197
+
198
+ learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
199
+
200
+ res = []
201
+ for indexes in res_indexes:
202
+ res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
203
+
204
+ return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
205
+
206
+
207
+ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
208
+ param = c[0][0].cond
209
+ res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
210
+ for i, cond_schedule in enumerate(c):
211
+ target_index = 0
212
+ for current, (end_at, cond) in enumerate(cond_schedule):
213
+ if current_step <= end_at:
214
+ target_index = current
215
+ break
216
+ res[i] = cond_schedule[target_index].cond
217
+
218
+ return res
219
+
220
+
221
+ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
222
+ param = c.batch[0][0].schedules[0].cond
223
+
224
+ tensors = []
225
+ conds_list = []
226
+
227
+ for batch_no, composable_prompts in enumerate(c.batch):
228
+ conds_for_batch = []
229
+
230
+ for cond_index, composable_prompt in enumerate(composable_prompts):
231
+ target_index = 0
232
+ for current, (end_at, cond) in enumerate(composable_prompt.schedules):
233
+ if current_step <= end_at:
234
+ target_index = current
235
+ break
236
+
237
+ conds_for_batch.append((len(tensors), composable_prompt.weight))
238
+ tensors.append(composable_prompt.schedules[target_index].cond)
239
+
240
+ conds_list.append(conds_for_batch)
241
+
242
+ return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
243
+
244
+
245
+ re_attention = re.compile(r"""
246
+ \\\(|
247
+ \\\)|
248
+ \\\[|
249
+ \\]|
250
+ \\\\|
251
+ \\|
252
+ \(|
253
+ \[|
254
+ :([+-]?[.\d]+)\)|
255
+ \)|
256
+ ]|
257
+ [^\\()\[\]:]+|
258
+ :
259
+ """, re.X)
260
+
261
+
262
+ def parse_prompt_attention(text):
263
+ """
264
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
265
+ Accepted tokens are:
266
+ (abc) - increases attention to abc by a multiplier of 1.1
267
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
268
+ [abc] - decreases attention to abc by a multiplier of 1.1
269
+ \( - literal character '('
270
+ \[ - literal character '['
271
+ \) - literal character ')'
272
+ \] - literal character ']'
273
+ \\ - literal character '\'
274
+ anything else - just text
275
+
276
+ >>> parse_prompt_attention('normal text')
277
+ [['normal text', 1.0]]
278
+ >>> parse_prompt_attention('an (important) word')
279
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
280
+ >>> parse_prompt_attention('(unbalanced')
281
+ [['unbalanced', 1.1]]
282
+ >>> parse_prompt_attention('\(literal\]')
283
+ [['(literal]', 1.0]]
284
+ >>> parse_prompt_attention('(unnecessary)(parens)')
285
+ [['unnecessaryparens', 1.1]]
286
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
287
+ [['a ', 1.0],
288
+ ['house', 1.5730000000000004],
289
+ [' ', 1.1],
290
+ ['on', 1.0],
291
+ [' a ', 1.1],
292
+ ['hill', 0.55],
293
+ [', sun, ', 1.1],
294
+ ['sky', 1.4641000000000006],
295
+ ['.', 1.1]]
296
+ """
297
+
298
+ res = []
299
+ round_brackets = []
300
+ square_brackets = []
301
+
302
+ round_bracket_multiplier = 1.1
303
+ square_bracket_multiplier = 1 / 1.1
304
+
305
+ def multiply_range(start_position, multiplier):
306
+ for p in range(start_position, len(res)):
307
+ res[p][1] *= multiplier
308
+
309
+ for m in re_attention.finditer(text):
310
+ text = m.group(0)
311
+ weight = m.group(1)
312
+
313
+ if text.startswith('\\'):
314
+ res.append([text[1:], 1.0])
315
+ elif text == '(':
316
+ round_brackets.append(len(res))
317
+ elif text == '[':
318
+ square_brackets.append(len(res))
319
+ elif weight is not None and len(round_brackets) > 0:
320
+ multiply_range(round_brackets.pop(), float(weight))
321
+ elif text == ')' and len(round_brackets) > 0:
322
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
323
+ elif text == ']' and len(square_brackets) > 0:
324
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
325
+ else:
326
+ res.append([text, 1.0])
327
+
328
+ for pos in round_brackets:
329
+ multiply_range(pos, round_bracket_multiplier)
330
+
331
+ for pos in square_brackets:
332
+ multiply_range(pos, square_bracket_multiplier)
333
+
334
+ if len(res) == 0:
335
+ res = [["", 1.0]]
336
+
337
+ # merge runs of identical weights
338
+ i = 0
339
+ while i + 1 < len(res):
340
+ if res[i][1] == res[i + 1][1]:
341
+ res[i][0] += res[i + 1][0]
342
+ res.pop(i + 1)
343
+ else:
344
+ i += 1
345
+
346
+ return res
347
+
348
+ if __name__ == "__main__":
349
+ import doctest
350
+ doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
351
+ else:
352
+ import torch # doctest faster
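For reference, a short usage sketch of the two parsers defined above; the expected results are copied from the doctests in this file (illustrative, not part of the upload):

from modules import prompt_parser

# prompt editing: "b" only enters the prompt after step 3 of 10
prompt_parser.get_learned_conditioning_prompt_schedules(["a [b:3]"], 10)
# -> [[[3, 'a '], [10, 'a b']]]

# attention weighting: (word) multiplies the weight of "word" by 1.1
prompt_parser.parse_prompt_attention('an (important) word')
# -> [['an ', 1.0], ['important', 1.1], [' word', 1.0]]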
modules/realesrgan_model.py ADDED
@@ -0,0 +1,135 @@
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ import numpy as np
6
+ from PIL import Image
7
+ from basicsr.utils.download_util import load_file_from_url
8
+ from realesrgan import RealESRGANer
9
+
10
+ from modules.upscaler import Upscaler, UpscalerData
11
+ from modules.paths import models_path
12
+ from modules.shared import cmd_opts, opts
13
+
14
+
15
+ class UpscalerRealESRGAN(Upscaler):
16
+ def __init__(self, path):
17
+ self.name = "RealESRGAN"
18
+ self.model_path = os.path.join(models_path, self.name)
19
+ self.user_path = path
20
+ super().__init__()
21
+ try:
22
+ from basicsr.archs.rrdbnet_arch import RRDBNet
23
+ from realesrgan import RealESRGANer
24
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
25
+ self.enable = True
26
+ self.scalers = []
27
+ scalers = self.load_models(path)
28
+ for scaler in scalers:
29
+ if scaler.name in opts.realesrgan_enabled_models:
30
+ self.scalers.append(scaler)
31
+
32
+ except Exception:
33
+ print("Error importing Real-ESRGAN:", file=sys.stderr)
34
+ print(traceback.format_exc(), file=sys.stderr)
35
+ self.enable = False
36
+ self.scalers = []
37
+
38
+ def do_upscale(self, img, path):
39
+ if not self.enable:
40
+ return img
41
+
42
+ info = self.load_model(path)
43
+ if info is None or not os.path.exists(info.data_path):
44
+ print("Unable to load RealESRGAN model: %s" % path)
45
+ return img
46
+
47
+ upsampler = RealESRGANer(
48
+ scale=info.scale,
49
+ model_path=info.data_path,
50
+ model=info.model(),
51
+ half=not cmd_opts.no_half,
52
+ tile=opts.ESRGAN_tile,
53
+ tile_pad=opts.ESRGAN_tile_overlap,
54
+ )
55
+
56
+ upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
57
+
58
+ image = Image.fromarray(upsampled)
59
+ return image
60
+
61
+ def load_model(self, path):
62
+ try:
63
+ info = None
64
+ for scaler in self.scalers:
65
+ if scaler.data_path == path:
66
+ info = scaler
67
+
68
+ if info is None:
69
+ print(f"Unable to find model info: {path}")
70
+ return None
71
+
72
+ model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
73
+ info.data_path = model_file
74
+ return info
75
+ except Exception as e:
76
+ print(f"Error loading Real-ESRGAN model: {e}", file=sys.stderr)
77
+ print(traceback.format_exc(), file=sys.stderr)
78
+ return None
79
+
80
+ def load_models(self, _):
81
+ return get_realesrgan_models(self)
82
+
83
+
84
+ def get_realesrgan_models(scaler):
85
+ try:
86
+ from basicsr.archs.rrdbnet_arch import RRDBNet
87
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
88
+ models = [
89
+ UpscalerData(
90
+ name="R-ESRGAN General 4xV3",
91
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
92
+ scale=4,
93
+ upscaler=scaler,
94
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
95
+ ),
96
+ UpscalerData(
97
+ name="R-ESRGAN General WDN 4xV3",
98
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
99
+ scale=4,
100
+ upscaler=scaler,
101
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
102
+ ),
103
+ UpscalerData(
104
+ name="R-ESRGAN AnimeVideo",
105
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
106
+ scale=4,
107
+ upscaler=scaler,
108
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
109
+ ),
110
+ UpscalerData(
111
+ name="R-ESRGAN 4x+",
112
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
113
+ scale=4,
114
+ upscaler=scaler,
115
+ model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
116
+ ),
117
+ UpscalerData(
118
+ name="R-ESRGAN 4x+ Anime6B",
119
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
120
+ scale=4,
121
+ upscaler=scaler,
122
+ model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
123
+ ),
124
+ UpscalerData(
125
+ name="R-ESRGAN 2x+",
126
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
127
+ scale=2,
128
+ upscaler=scaler,
129
+ model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
130
+ ),
131
+ ]
132
+ return models
133
+ except Exception as e:
134
+ print("Error making Real-ESRGAN models list:", file=sys.stderr)
135
+ print(traceback.format_exc(), file=sys.stderr)
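A hedged usage sketch for the upscaler above. It assumes the webui's shared options have been initialised (so at least one model name appears in opts.realesrgan_enabled_models) and that the selected weights can be downloaded on first use; the file names are placeholders:

from PIL import Image
from modules.realesrgan_model import UpscalerRealESRGAN

upscaler = UpscalerRealESRGAN("")                 # optional directory with user-supplied models
scaler = upscaler.scalers[0]                      # one of the enabled UpscalerData entries above
result = upscaler.do_upscale(Image.open("input.png"), scaler.data_path)
result.save("upscaled.png")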
modules/safety.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch
2
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
3
+ from transformers import AutoFeatureExtractor
4
+ from PIL import Image
5
+
6
+ import modules.shared as shared
7
+
8
+ safety_model_id = "CompVis/stable-diffusion-safety-checker"
9
+ safety_feature_extractor = None
10
+ safety_checker = None
11
+
12
+ def numpy_to_pil(images):
13
+ """
14
+ Convert a numpy image or a batch of images to a PIL image.
15
+ """
16
+ if images.ndim == 3:
17
+ images = images[None, ...]
18
+ images = (images * 255).round().astype("uint8")
19
+ pil_images = [Image.fromarray(image) for image in images]
20
+
21
+ return pil_images
22
+
23
+ # check and replace nsfw content
24
+ def check_safety(x_image):
25
+ global safety_feature_extractor, safety_checker
26
+
27
+ if safety_feature_extractor is None:
28
+ safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
29
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
30
+
31
+ safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
32
+ x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
33
+
34
+ return x_checked_image, has_nsfw_concept
35
+
36
+
37
+ def censor_batch(x):
38
+ x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
39
+ x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
40
+ x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
41
+
42
+ return x
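This module is only imported when opts.filter_nsfw is enabled (see process_images above). A minimal sketch of the expected tensor layout, assuming a batch of decoded images in NCHW format with values in [0, 1]; the first call downloads the CompVis safety-checker weights:

import torch
import modules.safety as safety

x_samples = torch.rand(2, 3, 512, 512)        # stand-in for a decoded batch, NCHW, values in [0, 1]
x_filtered = safety.censor_batch(x_samples)   # flagged images come back replaced by the checker
assert x_filtered.shape == x_samples.shape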
modules/scripts.py ADDED
@@ -0,0 +1,201 @@
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ import modules.ui as ui
6
+ import gradio as gr
7
+
8
+ from modules.processing import StableDiffusionProcessing
9
+ from modules import shared
10
+
11
+ class Script:
12
+ filename = None
13
+ args_from = None
14
+ args_to = None
15
+
16
+ # The title of the script. This is what will be displayed in the dropdown menu.
17
+ def title(self):
18
+ raise NotImplementedError()
19
+
20
+ # How the script is displayed in the UI. See https://gradio.app/docs/#components
21
+ # for the different UI components you can use and how to create them.
22
+ # Most UI components can return a value, such as a boolean for a checkbox.
23
+ # The returned values are passed to the run method as parameters.
24
+ def ui(self, is_img2img):
25
+ pass
26
+
27
+ # Determines when the script should be shown in the dropdown menu via the
28
+ # returned value. As an example:
29
+ # is_img2img is True if the current tab is img2img, and False if it is txt2img.
30
+ # Thus, return is_img2img to only show the script on the img2img tab.
31
+ def show(self, is_img2img):
32
+ return True
33
+
34
+ # This is where the additional processing is implemented. The parameters include
35
+ # self, the model object "p" (a StableDiffusionProcessing class, see
36
+ # processing.py), and the parameters returned by the ui method.
37
+ # Custom functions can be defined here, and additional libraries can be imported
38
+ # to be used in processing. The return value should be a Processed object, which is
39
+ # what is returned by the process_images method.
40
+ def run(self, *args):
41
+ raise NotImplementedError()
42
+
43
+ # The description method is currently unused.
44
+ # To add a description that appears when hovering over the title, amend the "titles"
45
+ # dict in script.js to include the script title (returned by title) as a key, and
46
+ # your description as the value.
47
+ def describe(self):
48
+ return ""
49
+
50
+
51
+ scripts_data = []
52
+
53
+
54
+ def load_scripts(basedir):
55
+ if not os.path.exists(basedir):
56
+ return
57
+
58
+ for filename in sorted(os.listdir(basedir)):
59
+ path = os.path.join(basedir, filename)
60
+
61
+ if not os.path.isfile(path):
62
+ continue
63
+
64
+ try:
65
+ with open(path, "r", encoding="utf8") as file:
66
+ text = file.read()
67
+
68
+ from types import ModuleType
69
+ compiled = compile(text, path, 'exec')
70
+ module = ModuleType(filename)
71
+ exec(compiled, module.__dict__)
72
+
73
+ for key, script_class in module.__dict__.items():
74
+ if type(script_class) == type and issubclass(script_class, Script):
75
+ scripts_data.append((script_class, path))
76
+
77
+ except Exception:
78
+ print(f"Error loading script: {filename}", file=sys.stderr)
79
+ print(traceback.format_exc(), file=sys.stderr)
80
+
81
+
82
+ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
83
+ try:
84
+ res = func(*args, **kwargs)
85
+ return res
86
+ except Exception:
87
+ print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
88
+ print(traceback.format_exc(), file=sys.stderr)
89
+
90
+ return default
91
+
92
+
93
+ class ScriptRunner:
94
+ def __init__(self):
95
+ self.scripts = []
96
+
97
+ def setup_ui(self, is_img2img):
98
+ for script_class, path in scripts_data:
99
+ script = script_class()
100
+ script.filename = path
101
+
102
+ if not script.show(is_img2img):
103
+ continue
104
+
105
+ self.scripts.append(script)
106
+
107
+ titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
108
+
109
+ dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index")
110
+ inputs = [dropdown]
111
+
112
+ for script in self.scripts:
113
+ script.args_from = len(inputs)
114
+ script.args_to = len(inputs)
115
+
116
+ controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
117
+
118
+ if controls is None:
119
+ continue
120
+
121
+ for control in controls:
122
+ control.custom_script_source = os.path.basename(script.filename)
123
+ control.visible = False
124
+
125
+ inputs += controls
126
+ script.args_to = len(inputs)
127
+
128
+ def select_script(script_index):
129
+ if 0 < script_index <= len(self.scripts):
130
+ script = self.scripts[script_index-1]
131
+ args_from = script.args_from
132
+ args_to = script.args_to
133
+ else:
134
+ args_from = 0
135
+ args_to = 0
136
+
137
+ return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
138
+
139
+ dropdown.change(
140
+ fn=select_script,
141
+ inputs=[dropdown],
142
+ outputs=inputs
143
+ )
144
+
145
+ return inputs
146
+
147
+ def run(self, p: StableDiffusionProcessing, *args):
148
+ script_index = args[0]
149
+
150
+ if script_index == 0:
151
+ return None
152
+
153
+ script = self.scripts[script_index-1]
154
+
155
+ if script is None:
156
+ return None
157
+
158
+ script_args = args[script.args_from:script.args_to]
159
+ processed = script.run(p, *script_args)
160
+
161
+ shared.total_tqdm.clear()
162
+
163
+ return processed
164
+
165
+ def reload_sources(self):
166
+ for si, script in list(enumerate(self.scripts)):
167
+ with open(script.filename, "r", encoding="utf8") as file:
168
+ args_from = script.args_from
169
+ args_to = script.args_to
170
+ filename = script.filename
171
+ text = file.read()
172
+
173
+ from types import ModuleType
174
+
175
+ compiled = compile(text, filename, 'exec')
176
+ module = ModuleType(script.filename)
177
+ exec(compiled, module.__dict__)
178
+
179
+ for key, script_class in module.__dict__.items():
180
+ if type(script_class) == type and issubclass(script_class, Script):
181
+ self.scripts[si] = script_class()
182
+ self.scripts[si].filename = filename
183
+ self.scripts[si].args_from = args_from
184
+ self.scripts[si].args_to = args_to
185
+
186
+ scripts_txt2img = ScriptRunner()
187
+ scripts_img2img = ScriptRunner()
188
+
189
+ def reload_script_body_only():
190
+ scripts_txt2img.reload_sources()
191
+ scripts_img2img.reload_sources()
192
+
193
+
194
+ def reload_scripts(basedir):
195
+ global scripts_txt2img, scripts_img2img
196
+
197
+ scripts_data.clear()
198
+ load_scripts(basedir)
199
+
200
+ scripts_txt2img = ScriptRunner()
201
+ scripts_img2img = ScriptRunner()
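Following the contract spelled out in the Script class comments above, a hypothetical minimal custom script (for example saved as scripts/example.py) might look like this; the name and behaviour are illustrative only:

import gradio as gr
import modules.scripts as scripts
from modules.processing import process_images

class ExampleScript(scripts.Script):
    def title(self):
        return "Example: append text to prompt"

    def ui(self, is_img2img):
        text = gr.Textbox(label="Text to append")     # returned controls become run() arguments
        return [text]

    def show(self, is_img2img):
        return True                                   # show on both txt2img and img2img tabs

    def run(self, p, text):
        p.prompt = f"{p.prompt} {text}"               # assumes a plain string prompt
        return process_images(p)                      # must return a Processed object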
modules/scunet_model.py ADDED
@@ -0,0 +1,90 @@
1
+ import os.path
2
+ import sys
3
+ import traceback
4
+
5
+ import PIL.Image
6
+ import numpy as np
7
+ import torch
8
+ from basicsr.utils.download_util import load_file_from_url
9
+
10
+ import modules.upscaler
11
+ from modules import devices, modelloader
12
+ from modules.paths import models_path
13
+ from modules.scunet_model_arch import SCUNet as net
14
+
15
+
16
+ class UpscalerScuNET(modules.upscaler.Upscaler):
17
+ def __init__(self, dirname):
18
+ self.name = "ScuNET"
19
+ self.model_path = os.path.join(models_path, self.name)
20
+ self.model_name = "ScuNET GAN"
21
+ self.model_name2 = "ScuNET PSNR"
22
+ self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
23
+ self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
24
+ self.user_path = dirname
25
+ super().__init__()
26
+ model_paths = self.find_models(ext_filter=[".pth"])
27
+ scalers = []
28
+ add_model2 = True
29
+ for file in model_paths:
30
+ if "http" in file:
31
+ name = self.model_name
32
+ else:
33
+ name = modelloader.friendly_name(file)
34
+ if name == self.model_name2 or file == self.model_url2:
35
+ add_model2 = False
36
+ try:
37
+ scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
38
+ scalers.append(scaler_data)
39
+ except Exception:
40
+ print(f"Error loading ScuNET model: {file}", file=sys.stderr)
41
+ print(traceback.format_exc(), file=sys.stderr)
42
+ if add_model2:
43
+ scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
44
+ scalers.append(scaler_data2)
45
+ self.scalers = scalers
46
+
47
+ def do_upscale(self, img: PIL.Image, selected_file):
48
+ torch.cuda.empty_cache()
49
+
50
+ model = self.load_model(selected_file)
51
+ if model is None:
52
+ return img
53
+
54
+ device = devices.device_scunet
55
+ img = np.array(img)
56
+ img = img[:, :, ::-1]
57
+ img = np.moveaxis(img, 2, 0) / 255
58
+ img = torch.from_numpy(img).float()
59
+ img = img.unsqueeze(0).to(device)
60
+
61
+ img = img.to(device)
62
+ with torch.no_grad():
63
+ output = model(img)
64
+ output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
65
+ output = 255. * np.moveaxis(output, 0, 2)
66
+ output = output.astype(np.uint8)
67
+ output = output[:, :, ::-1]
68
+ torch.cuda.empty_cache()
69
+ return PIL.Image.fromarray(output, 'RGB')
70
+
71
+ def load_model(self, path: str):
72
+ device = devices.device_scunet
73
+ if "http" in path:
74
+ filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
75
+ progress=True)
76
+ else:
77
+ filename = path
78
+ if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
79
+ print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
80
+ return None
81
+
82
+ model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
83
+ model.load_state_dict(torch.load(filename), strict=True)
84
+ model.eval()
85
+ for k, v in model.named_parameters():
86
+ v.requires_grad = False
87
+ model = model.to(device)
88
+
89
+ return model
90
+
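Both this upscaler and the BSRGAN one convert PIL images to BGR, channel-first float tensors in [0, 1] before inference and reverse the conversion afterwards. A rough sketch of that round trip (the model call is elided and the file name is a placeholder):

import numpy as np
import torch
from PIL import Image

img = np.array(Image.open("input.png").convert("RGB"))
t = torch.from_numpy(np.moveaxis(img[:, :, ::-1], 2, 0) / 255.0).float().unsqueeze(0)
# ... run the denoising/upscaling model on t here ...
out = t.squeeze(0).clamp_(0, 1).numpy()
restored = Image.fromarray((255.0 * np.moveaxis(out, 0, 2)).astype(np.uint8)[:, :, ::-1], "RGB")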
modules/scunet_model_arch.py ADDED
@@ -0,0 +1,265 @@
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from einops.layers.torch import Rearrange
7
+ from timm.models.layers import trunc_normal_, DropPath
8
+
9
+
10
+ class WMSA(nn.Module):
11
+ """ Self-attention module in Swin Transformer
12
+ """
13
+
14
+ def __init__(self, input_dim, output_dim, head_dim, window_size, type):
15
+ super(WMSA, self).__init__()
16
+ self.input_dim = input_dim
17
+ self.output_dim = output_dim
18
+ self.head_dim = head_dim
19
+ self.scale = self.head_dim ** -0.5
20
+ self.n_heads = input_dim // head_dim
21
+ self.window_size = window_size
22
+ self.type = type
23
+ self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
24
+
25
+ self.relative_position_params = nn.Parameter(
26
+ torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
27
+
28
+ self.linear = nn.Linear(self.input_dim, self.output_dim)
29
+
30
+ trunc_normal_(self.relative_position_params, std=.02)
31
+ self.relative_position_params = torch.nn.Parameter(
32
+ self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
33
+ 2).transpose(
34
+ 0, 1))
35
+
36
+ def generate_mask(self, h, w, p, shift):
37
+ """ generating the mask of SW-MSA
38
+ Args:
39
+ shift: shift parameters in CyclicShift.
40
+ Returns:
41
+ attn_mask: should be (1 1 w p p),
42
+ """
43
+ # supporting square windows only.
44
+ attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
45
+ if self.type == 'W':
46
+ return attn_mask
47
+
48
+ s = p - shift
49
+ attn_mask[-1, :, :s, :, s:, :] = True
50
+ attn_mask[-1, :, s:, :, :s, :] = True
51
+ attn_mask[:, -1, :, :s, :, s:] = True
52
+ attn_mask[:, -1, :, s:, :, :s] = True
53
+ attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
54
+ return attn_mask
55
+
56
+ def forward(self, x):
57
+ """ Forward pass of Window Multi-head Self-attention module.
58
+ Args:
59
+ x: input tensor with shape of [b h w c];
60
+ attn_mask: attention mask, fill -inf where the value is True;
61
+ Returns:
62
+ output: tensor shape [b h w c]
63
+ """
64
+ if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
65
+ x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
66
+ h_windows = x.size(1)
67
+ w_windows = x.size(2)
68
+ # square validation
69
+ # assert h_windows == w_windows
70
+
71
+ x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
72
+ qkv = self.embedding_layer(x)
73
+ q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
74
+ sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
75
+ # Adding learnable relative embedding
76
+ sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
77
+ # Using Attn Mask to distinguish different subwindows.
78
+ if self.type != 'W':
79
+ attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
80
+ sim = sim.masked_fill_(attn_mask, float("-inf"))
81
+
82
+ probs = nn.functional.softmax(sim, dim=-1)
83
+ output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
84
+ output = rearrange(output, 'h b w p c -> b w p (h c)')
85
+ output = self.linear(output)
86
+ output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
87
+
88
+ if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
89
+ dims=(1, 2))
90
+ return output
91
+
92
+ def relative_embedding(self):
93
+ cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
94
+ relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
95
+ # negative is allowed
96
+ return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
97
+
98
+
99
+ class Block(nn.Module):
100
+ def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
101
+ """ SwinTransformer Block
102
+ """
103
+ super(Block, self).__init__()
104
+ self.input_dim = input_dim
105
+ self.output_dim = output_dim
106
+ assert type in ['W', 'SW']
107
+ self.type = type
108
+ if input_resolution <= window_size:
109
+ self.type = 'W'
110
+
111
+ self.ln1 = nn.LayerNorm(input_dim)
112
+ self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
113
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
114
+ self.ln2 = nn.LayerNorm(input_dim)
115
+ self.mlp = nn.Sequential(
116
+ nn.Linear(input_dim, 4 * input_dim),
117
+ nn.GELU(),
118
+ nn.Linear(4 * input_dim, output_dim),
119
+ )
120
+
121
+ def forward(self, x):
122
+ x = x + self.drop_path(self.msa(self.ln1(x)))
123
+ x = x + self.drop_path(self.mlp(self.ln2(x)))
124
+ return x
125
+
126
+
127
+ class ConvTransBlock(nn.Module):
128
+ def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
129
+ """ SwinTransformer and Conv Block
130
+ """
131
+ super(ConvTransBlock, self).__init__()
132
+ self.conv_dim = conv_dim
133
+ self.trans_dim = trans_dim
134
+ self.head_dim = head_dim
135
+ self.window_size = window_size
136
+ self.drop_path = drop_path
137
+ self.type = type
138
+ self.input_resolution = input_resolution
139
+
140
+ assert self.type in ['W', 'SW']
141
+ if self.input_resolution <= self.window_size:
142
+ self.type = 'W'
143
+
144
+ self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
145
+ self.type, self.input_resolution)
146
+ self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
147
+ self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
148
+
149
+ self.conv_block = nn.Sequential(
150
+ nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
151
+ nn.ReLU(True),
152
+ nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
153
+ )
154
+
155
+ def forward(self, x):
156
+ conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
157
+ conv_x = self.conv_block(conv_x) + conv_x
158
+ trans_x = Rearrange('b c h w -> b h w c')(trans_x)
159
+ trans_x = self.trans_block(trans_x)
160
+ trans_x = Rearrange('b h w c -> b c h w')(trans_x)
161
+ res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
162
+ x = x + res
163
+
164
+ return x
165
+
166
+
167
+ class SCUNet(nn.Module):
168
+ # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
169
+ def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
170
+ super(SCUNet, self).__init__()
171
+ if config is None:
172
+ config = [2, 2, 2, 2, 2, 2, 2]
173
+ self.config = config
174
+ self.dim = dim
175
+ self.head_dim = 32
176
+ self.window_size = 8
177
+
178
+ # drop path rate for each layer
179
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
180
+
181
+ self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
182
+
183
+ begin = 0
184
+ self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
185
+ 'W' if not i % 2 else 'SW', input_resolution)
186
+ for i in range(config[0])] + \
187
+ [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
188
+
189
+ begin += config[0]
190
+ self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
191
+ 'W' if not i % 2 else 'SW', input_resolution // 2)
192
+ for i in range(config[1])] + \
193
+ [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
194
+
195
+ begin += config[1]
196
+ self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
197
+ 'W' if not i % 2 else 'SW', input_resolution // 4)
198
+ for i in range(config[2])] + \
199
+ [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
200
+
201
+ begin += config[2]
202
+ self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
203
+ 'W' if not i % 2 else 'SW', input_resolution // 8)
204
+ for i in range(config[3])]
205
+
206
+ begin += config[3]
207
+ self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
208
+ [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
209
+ 'W' if not i % 2 else 'SW', input_resolution // 4)
210
+ for i in range(config[4])]
211
+
212
+ begin += config[4]
213
+ self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
214
+ [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
215
+ 'W' if not i % 2 else 'SW', input_resolution // 2)
216
+ for i in range(config[5])]
217
+
218
+ begin += config[5]
219
+ self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
220
+ [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
221
+ 'W' if not i % 2 else 'SW', input_resolution)
222
+ for i in range(config[6])]
223
+
224
+ self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
225
+
226
+ self.m_head = nn.Sequential(*self.m_head)
227
+ self.m_down1 = nn.Sequential(*self.m_down1)
228
+ self.m_down2 = nn.Sequential(*self.m_down2)
229
+ self.m_down3 = nn.Sequential(*self.m_down3)
230
+ self.m_body = nn.Sequential(*self.m_body)
231
+ self.m_up3 = nn.Sequential(*self.m_up3)
232
+ self.m_up2 = nn.Sequential(*self.m_up2)
233
+ self.m_up1 = nn.Sequential(*self.m_up1)
234
+ self.m_tail = nn.Sequential(*self.m_tail)
235
+ # self.apply(self._init_weights)
236
+
237
+ def forward(self, x0):
238
+
239
+ h, w = x0.size()[-2:]
240
+ paddingBottom = int(np.ceil(h / 64) * 64 - h)
241
+ paddingRight = int(np.ceil(w / 64) * 64 - w)
242
+ x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
243
+
244
+ x1 = self.m_head(x0)
245
+ x2 = self.m_down1(x1)
246
+ x3 = self.m_down2(x2)
247
+ x4 = self.m_down3(x3)
248
+ x = self.m_body(x4)
249
+ x = self.m_up3(x + x4)
250
+ x = self.m_up2(x + x3)
251
+ x = self.m_up1(x + x2)
252
+ x = self.m_tail(x + x1)
253
+
254
+ x = x[..., :h, :w]
255
+
256
+ return x
257
+
258
+ def _init_weights(self, m):
259
+ if isinstance(m, nn.Linear):
260
+ trunc_normal_(m.weight, std=.02)
261
+ if m.bias is not None:
262
+ nn.init.constant_(m.bias, 0)
263
+ elif isinstance(m, nn.LayerNorm):
264
+ nn.init.constant_(m.bias, 0)
265
+ nn.init.constant_(m.weight, 1.0)
modules/sd_hijack.py ADDED
@@ -0,0 +1,321 @@
1
+ import math
2
+ import os
3
+ import sys
4
+ import traceback
5
+ import torch
6
+ import numpy as np
7
+ from torch import einsum
8
+ from torch.nn.functional import silu
9
+
10
+ import modules.textual_inversion.textual_inversion
11
+ from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
12
+ from modules.shared import opts, device, cmd_opts
13
+
14
+ import ldm.modules.attention
15
+ import ldm.modules.diffusionmodules.model
16
+
17
+ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
18
+ diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
19
+ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
20
+
21
+
22
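+ # replace ldm's nonlinearity with torch's fused silu and, depending on command line options, swap the CrossAttention/AttnBlock forward passes for the memory-saving versions from sd_hijack_optimizations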
+ def apply_optimizations():
23
+ undo_optimizations()
24
+
25
+ ldm.modules.diffusionmodules.model.nonlinearity = silu
26
+
27
+ if cmd_opts.opt_split_attention_v1:
28
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
29
+ elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
30
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
31
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
32
+
33
+
34
+ def undo_optimizations():
35
+ ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
36
+ ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
37
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
38
+
39
+
40
+ class StableDiffusionModelHijack:
41
+ fixes = None
42
+ comments = []
43
+ layers = None
44
+ circular_enabled = False
45
+ clip = None
46
+
47
+ embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
48
+
49
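+ # wrap the CLIP token-embedding layer and the whole cond_stage_model so that textual inversion embeddings and prompt emphasis can be injected during text encoding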
+ def hijack(self, m):
50
+ model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
51
+
52
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
53
+ m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
54
+
55
+ self.clip = m.cond_stage_model
56
+
57
+ apply_optimizations()
58
+
59
+ def flatten(el):
60
+ flattened = [flatten(children) for children in el.children()]
61
+ res = [el]
62
+ for c in flattened:
63
+ res += c
64
+ return res
65
+
66
+ self.layers = flatten(m)
67
+
68
+ def undo_hijack(self, m):
69
+ if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
70
+ m.cond_stage_model = m.cond_stage_model.wrapped
71
+
72
+ model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
73
+ if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
74
+ model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
75
+
76
+ def apply_circular(self, enable):
77
+ if self.circular_enabled == enable:
78
+ return
79
+
80
+ self.circular_enabled = enable
81
+
82
+ for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
83
+ layer.padding_mode = 'circular' if enable else 'zeros'
84
+
85
+ def tokenize(self, text):
86
+ max_length = self.clip.max_length - 2
87
+ _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
88
+ return remade_batch_tokens[0], token_count, max_length
89
+
90
+
91
+ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
92
+ def __init__(self, wrapped, hijack):
93
+ super().__init__()
94
+ self.wrapped = wrapped
95
+ self.hijack: StableDiffusionModelHijack = hijack
96
+ self.tokenizer = wrapped.tokenizer
97
+ self.max_length = wrapped.max_length
98
+ self.token_mults = {}
99
+
100
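+ # precompute emphasis multipliers for vocabulary entries that themselves contain () or []: each '(' multiplies attention by 1.1, each '[' divides by 1.1 (used by the old emphasis implementation below)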
+ tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
101
+ for text, ident in tokens_with_parens:
102
+ mult = 1.0
103
+ for c in text:
104
+ if c == '[':
105
+ mult /= 1.1
106
+ if c == ']':
107
+ mult *= 1.1
108
+ if c == '(':
109
+ mult *= 1.1
110
+ if c == ')':
111
+ mult /= 1.1
112
+
113
+ if mult != 1.0:
114
+ self.token_mults[ident] = mult
115
+
116
+ def tokenize_line(self, line, used_custom_terms, hijack_comments):
117
+ id_start = self.wrapped.tokenizer.bos_token_id
118
+ id_end = self.wrapped.tokenizer.eos_token_id
119
+ maxlen = self.wrapped.max_length
120
+
121
+ if opts.enable_emphasis:
122
+ parsed = prompt_parser.parse_prompt_attention(line)
123
+ else:
124
+ parsed = [[line, 1.0]]
125
+
126
+ tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
127
+
128
+ fixes = []
129
+ remade_tokens = []
130
+ multipliers = []
131
+
132
+ for tokens, (text, weight) in zip(tokenized, parsed):
133
+ i = 0
134
+ while i < len(tokens):
135
+ token = tokens[i]
136
+
137
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
138
+
139
+ if embedding is None:
140
+ remade_tokens.append(token)
141
+ multipliers.append(weight)
142
+ i += 1
143
+ else:
144
+ emb_len = int(embedding.vec.shape[0])
145
+ fixes.append((len(remade_tokens), embedding))
146
+ remade_tokens += [0] * emb_len
147
+ multipliers += [weight] * emb_len
148
+ used_custom_terms.append((embedding.name, embedding.checksum()))
149
+ i += embedding_length_in_tokens
150
+
151
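+ # maxlen includes the BOS/EOS tokens added below, so only maxlen - 2 prompt tokens fit; anything beyond that is reported and truncated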
+ if len(remade_tokens) > maxlen - 2:
152
+ vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
153
+ ovf = remade_tokens[maxlen - 2:]
154
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
155
+ overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
156
+ hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
157
+
158
+ token_count = len(remade_tokens)
159
+ remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
160
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
161
+
162
+ multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
163
+ multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
164
+
165
+ return remade_tokens, fixes, multipliers, token_count
166
+
167
+ def process_text(self, texts):
168
+ used_custom_terms = []
169
+ remade_batch_tokens = []
170
+ hijack_comments = []
171
+ hijack_fixes = []
172
+ token_count = 0
173
+
174
+ cache = {}
175
+ batch_multipliers = []
176
+ for line in texts:
177
+ if line in cache:
178
+ remade_tokens, fixes, multipliers = cache[line]
179
+ else:
180
+ remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
181
+
182
+ cache[line] = (remade_tokens, fixes, multipliers)
183
+
184
+ remade_batch_tokens.append(remade_tokens)
185
+ hijack_fixes.append(fixes)
186
+ batch_multipliers.append(multipliers)
187
+
188
+ return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
189
+
190
+
191
+ def process_text_old(self, text):
192
+ id_start = self.wrapped.tokenizer.bos_token_id
193
+ id_end = self.wrapped.tokenizer.eos_token_id
194
+ maxlen = self.wrapped.max_length
195
+ used_custom_terms = []
196
+ remade_batch_tokens = []
197
+ overflowing_words = []
198
+ hijack_comments = []
199
+ hijack_fixes = []
200
+ token_count = 0
201
+
202
+ cache = {}
203
+ batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
204
+ batch_multipliers = []
205
+ for tokens in batch_tokens:
206
+ tuple_tokens = tuple(tokens)
207
+
208
+ if tuple_tokens in cache:
209
+ remade_tokens, fixes, multipliers = cache[tuple_tokens]
210
+ else:
211
+ fixes = []
212
+ remade_tokens = []
213
+ multipliers = []
214
+ mult = 1.0
215
+
216
+ i = 0
217
+ while i < len(tokens):
218
+ token = tokens[i]
219
+
220
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
221
+
222
+ mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
223
+ if mult_change is not None:
224
+ mult *= mult_change
225
+ i += 1
226
+ elif embedding is None:
227
+ remade_tokens.append(token)
228
+ multipliers.append(mult)
229
+ i += 1
230
+ else:
231
+ emb_len = int(embedding.vec.shape[0])
232
+ fixes.append((len(remade_tokens), embedding))
233
+ remade_tokens += [0] * emb_len
234
+ multipliers += [mult] * emb_len
235
+ used_custom_terms.append((embedding.name, embedding.checksum()))
236
+ i += embedding_length_in_tokens
237
+
238
+ if len(remade_tokens) > maxlen - 2:
239
+ vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
240
+ ovf = remade_tokens[maxlen - 2:]
241
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
242
+ overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
243
+ hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
244
+
245
+ token_count = len(remade_tokens)
246
+ remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
247
+ remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
248
+ cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
249
+
250
+ multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
251
+ multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
252
+
253
+ remade_batch_tokens.append(remade_tokens)
254
+ hijack_fixes.append(fixes)
255
+ batch_multipliers.append(multipliers)
256
+ return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
257
+
258
+ def forward(self, text):
259
+
260
+ if opts.use_old_emphasis_implementation:
261
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
262
+ else:
263
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
264
+
265
+ self.hijack.fixes = hijack_fixes
266
+ self.hijack.comments = hijack_comments
267
+
268
+ if len(used_custom_terms) > 0:
269
+ self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
270
+
271
+ tokens = torch.asarray(remade_batch_tokens).to(device)
272
+ outputs = self.wrapped.transformer(input_ids=tokens)
273
+ z = outputs.last_hidden_state
274
+
275
+ # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
276
+ batch_multipliers = torch.asarray(batch_multipliers).to(device)
277
+ original_mean = z.mean()
278
+ z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
279
+ new_mean = z.mean()
280
+ z *= original_mean / new_mean
281
+
282
+ return z
283
+
284
+
285
+ class EmbeddingsWithFixes(torch.nn.Module):
286
+ def __init__(self, wrapped, embeddings):
287
+ super().__init__()
288
+ self.wrapped = wrapped
289
+ self.embeddings = embeddings
290
+
291
+ def forward(self, input_ids):
292
+ batch_fixes = self.embeddings.fixes
293
+ self.embeddings.fixes = None
294
+
295
+ inputs_embeds = self.wrapped(input_ids)
296
+
297
+ if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
298
+ return inputs_embeds
299
+
300
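+ # splice each embedding's vectors into the token embeddings at the offsets recorded during tokenization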
+ vecs = []
301
+ for fixes, tensor in zip(batch_fixes, inputs_embeds):
302
+ for offset, embedding in fixes:
303
+ emb = embedding.vec
304
+ emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
305
+ tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
306
+
307
+ vecs.append(tensor)
308
+
309
+ return torch.stack(vecs)
310
+
311
+
312
+ def add_circular_option_to_conv_2d():
313
+ conv2d_constructor = torch.nn.Conv2d.__init__
314
+
315
+ def conv2d_constructor_circular(self, *args, **kwargs):
316
+ return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
317
+
318
+ torch.nn.Conv2d.__init__ = conv2d_constructor_circular
319
+
320
+
321
+ model_hijack = StableDiffusionModelHijack()
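+ # typical usage (as done in sd_models.py below): model_hijack.hijack(sd_model) after the weights are loaded, and model_hijack.undo_hijack(sd_model) before loading a different checkpoint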
modules/sd_hijack_optimizations.py ADDED
@@ -0,0 +1,169 @@
1
+ import math
2
+ import torch
3
+ from torch import einsum
4
+
5
+ from ldm.util import default
6
+ from einops import rearrange
7
+
8
+ from modules import shared
9
+
10
+
11
+ # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
12
+ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
13
+ h = self.heads
14
+
15
+ q = self.to_q(x)
16
+ context = default(context, x)
17
+ k = self.to_k(context)
18
+ v = self.to_v(context)
19
+ del context, x
20
+
21
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
22
+
23
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
24
+ for i in range(0, q.shape[0], 2):
25
+ end = i + 2
26
+ s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
27
+ s1 *= self.scale
28
+
29
+ s2 = s1.softmax(dim=-1)
30
+ del s1
31
+
32
+ r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
33
+ del s2
34
+
35
+ r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
36
+ del r1
37
+
38
+ return self.to_out(r2)
39
+
40
+
41
+ # taken from https://github.com/Doggettx/stable-diffusion
42
+ def split_cross_attention_forward(self, x, context=None, mask=None):
43
+ h = self.heads
44
+
45
+ q_in = self.to_q(x)
46
+ context = default(context, x)
47
+
48
+ hypernetwork = shared.selected_hypernetwork()
49
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
50
+
51
+ if hypernetwork_layers is not None:
52
+ k_in = self.to_k(hypernetwork_layers[0](context))
53
+ v_in = self.to_v(hypernetwork_layers[1](context))
54
+ else:
55
+ k_in = self.to_k(context)
56
+ v_in = self.to_v(context)
57
+
58
+ k_in *= self.scale
59
+
60
+ del context, x
61
+
62
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
63
+ del q_in, k_in, v_in
64
+
65
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
66
+
67
+ stats = torch.cuda.memory_stats(q.device)
68
+ mem_active = stats['active_bytes.all.current']
69
+ mem_reserved = stats['reserved_bytes.all.current']
70
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
71
+ mem_free_torch = mem_reserved - mem_active
72
+ mem_free_total = mem_free_cuda + mem_free_torch
73
+
74
+ gb = 1024 ** 3
75
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
76
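+ # empirical safety factor over the raw size of the attention matrix (a larger factor is used for half-precision tensors)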
+ modifier = 3 if q.element_size() == 2 else 2.5
77
+ mem_required = tensor_size * modifier
78
+ steps = 1
79
+
80
+ if mem_required > mem_free_total:
81
+ steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
82
+ # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
83
+ # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
84
+
85
+ if steps > 64:
86
+ max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
87
+ raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
88
+ f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
89
+
90
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
91
+ for i in range(0, q.shape[1], slice_size):
92
+ end = i + slice_size
93
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
94
+
95
+ s2 = s1.softmax(dim=-1, dtype=q.dtype)
96
+ del s1
97
+
98
+ r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
99
+ del s2
100
+
101
+ del q, k, v
102
+
103
+ r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
104
+ del r1
105
+
106
+ return self.to_out(r2)
107
+
108
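+ # sliced replacement for ldm's AttnBlock.forward: the hw x hw attention matrix is computed in chunks sized to the reported free CUDA memory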
+ def cross_attention_attnblock_forward(self, x):
109
+ h_ = x
110
+ h_ = self.norm(h_)
111
+ q1 = self.q(h_)
112
+ k1 = self.k(h_)
113
+ v = self.v(h_)
114
+
115
+ # compute attention
116
+ b, c, h, w = q1.shape
117
+
118
+ q2 = q1.reshape(b, c, h*w)
119
+ del q1
120
+
121
+ q = q2.permute(0, 2, 1) # b,hw,c
122
+ del q2
123
+
124
+ k = k1.reshape(b, c, h*w) # b,c,hw
125
+ del k1
126
+
127
+ h_ = torch.zeros_like(k, device=q.device)
128
+
129
+ stats = torch.cuda.memory_stats(q.device)
130
+ mem_active = stats['active_bytes.all.current']
131
+ mem_reserved = stats['reserved_bytes.all.current']
132
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
133
+ mem_free_torch = mem_reserved - mem_active
134
+ mem_free_total = mem_free_cuda + mem_free_torch
135
+
136
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
137
+ mem_required = tensor_size * 2.5
138
+ steps = 1
139
+
140
+ if mem_required > mem_free_total:
141
+ steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
142
+
143
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
144
+ for i in range(0, q.shape[1], slice_size):
145
+ end = i + slice_size
146
+
147
+ w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
148
+ w2 = w1 * (int(c)**(-0.5))
149
+ del w1
150
+ w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
151
+ del w2
152
+
153
+ # attend to values
154
+ v1 = v.reshape(b, c, h*w)
155
+ w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
156
+ del w3
157
+
158
+ h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
159
+ del v1, w4
160
+
161
+ h2 = h_.reshape(b, c, h, w)
162
+ del h_
163
+
164
+ h3 = self.proj_out(h2)
165
+ del h2
166
+
167
+ h3 += x
168
+
169
+ return h3
modules/sd_models.py ADDED
@@ -0,0 +1,192 @@
1
+ import glob
2
+ import os.path
3
+ import sys
4
+ from collections import namedtuple
5
+ import torch
6
+ from omegaconf import OmegaConf
7
+
8
+
9
+ from ldm.util import instantiate_from_config
10
+
11
+ from modules import shared, modelloader, devices
12
+ from modules.paths import models_path
13
+
14
+ model_dir = "Stable-diffusion"
15
+ model_path = os.path.abspath(os.path.join(models_path, model_dir))
16
+
17
+ CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
18
+ checkpoints_list = {}
19
+
20
+ try:
21
+ # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
22
+
23
+ from transformers import logging
24
+
25
+ logging.set_verbosity_error()
26
+ except Exception:
27
+ pass
28
+
29
+
30
+ def setup_model():
31
+ if not os.path.exists(model_path):
32
+ os.makedirs(model_path)
33
+
34
+ list_models()
35
+
36
+
37
+ def checkpoint_tiles():
38
+ return sorted([x.title for x in checkpoints_list.values()])
39
+
40
+
41
+ def list_models():
42
+ checkpoints_list.clear()
43
+ model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
44
+
45
+ def modeltitle(path, shorthash):
46
+ abspath = os.path.abspath(path)
47
+
48
+ if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
49
+ name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
50
+ elif abspath.startswith(model_path):
51
+ name = abspath.replace(model_path, '')
52
+ else:
53
+ name = os.path.basename(path)
54
+
55
+ if name.startswith("\\") or name.startswith("/"):
56
+ name = name[1:]
57
+
58
+ shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
59
+
60
+ return f'{name} [{shorthash}]', shortname
61
+
62
+ cmd_ckpt = shared.cmd_opts.ckpt
63
+ if os.path.exists(cmd_ckpt):
64
+ h = model_hash(cmd_ckpt)
65
+ title, short_model_name = modeltitle(cmd_ckpt, h)
66
+ checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
67
+ shared.opts.data['sd_model_checkpoint'] = title
68
+ elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
69
+ print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
70
+ for filename in model_list:
71
+ h = model_hash(filename)
72
+ title, short_model_name = modeltitle(filename, h)
73
+ checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
74
+
75
+
76
+ def get_closet_checkpoint_match(searchString):
77
+ applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
78
+ if len(applicable) > 0:
79
+ return applicable[0]
80
+ return None
81
+
82
+
83
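+ # cheap fingerprint: sha256 over 64 KiB read at offset 1 MiB of the checkpoint, truncated to 8 hex characters (fast to compute, not collision-proof)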
+ def model_hash(filename):
84
+ try:
85
+ with open(filename, "rb") as file:
86
+ import hashlib
87
+ m = hashlib.sha256()
88
+
89
+ file.seek(0x100000)
90
+ m.update(file.read(0x10000))
91
+ return m.hexdigest()[0:8]
92
+ except FileNotFoundError:
93
+ return 'NOFILE'
94
+
95
+
96
+ def select_checkpoint():
97
+ model_checkpoint = shared.opts.sd_model_checkpoint
98
+ checkpoint_info = checkpoints_list.get(model_checkpoint, None)
99
+ if checkpoint_info is not None:
100
+ return checkpoint_info
101
+
102
+ if len(checkpoints_list) == 0:
103
+ print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
104
+ if shared.cmd_opts.ckpt is not None:
105
+ print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
106
+ print(f" - directory {model_path}", file=sys.stderr)
107
+ if shared.cmd_opts.ckpt_dir is not None:
108
+ print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
109
+ print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
110
+ exit(1)
111
+
112
+ checkpoint_info = next(iter(checkpoints_list.values()))
113
+ if model_checkpoint is not None:
114
+ print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)
115
+
116
+ return checkpoint_info
117
+
118
+
119
+ def load_model_weights(model, checkpoint_file, sd_model_hash):
120
+ print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
121
+
122
+ pl_sd = torch.load(checkpoint_file, map_location="cpu")
123
+ if "global_step" in pl_sd:
124
+ print(f"Global Step: {pl_sd['global_step']}")
125
+ sd = pl_sd["state_dict"]
126
+
127
+ model.load_state_dict(sd, strict=False)
128
+
129
+ if shared.cmd_opts.opt_channelslast:
130
+ model.to(memory_format=torch.channels_last)
131
+
132
+ if not shared.cmd_opts.no_half:
133
+ model.half()
134
+
135
+ devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
136
+
137
+ vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
138
+ if os.path.exists(vae_file):
139
+ print(f"Loading VAE weights from: {vae_file}")
140
+ vae_ckpt = torch.load(vae_file, map_location="cpu")
141
+ vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
142
+
143
+ model.first_stage_model.load_state_dict(vae_dict)
144
+
145
+ model.sd_model_hash = sd_model_hash
146
+ model.sd_model_checkpint = checkpoint_file
147
+
148
+
149
+ def load_model():
150
+ from modules import lowvram, sd_hijack
151
+ checkpoint_info = select_checkpoint()
152
+
153
+ sd_config = OmegaConf.load(shared.cmd_opts.config)
154
+ sd_model = instantiate_from_config(sd_config.model)
155
+ load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
156
+
157
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
158
+ lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
159
+ else:
160
+ sd_model.to(shared.device)
161
+
162
+ sd_hijack.model_hijack.hijack(sd_model)
163
+
164
+ sd_model.eval()
165
+
166
+ print(f"Model loaded.")
167
+ return sd_model
168
+
169
+
170
+ def reload_model_weights(sd_model, info=None):
171
+ from modules import lowvram, devices, sd_hijack
172
+ checkpoint_info = info or select_checkpoint()
173
+
174
+ if sd_model.sd_model_checkpint == checkpoint_info.filename:
175
+ return
176
+
177
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
178
+ lowvram.send_everything_to_cpu()
179
+ else:
180
+ sd_model.to(devices.cpu)
181
+
182
+ sd_hijack.model_hijack.undo_hijack(sd_model)
183
+
184
+ load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
185
+
186
+ sd_hijack.model_hijack.hijack(sd_model)
187
+
188
+ if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
189
+ sd_model.to(devices.device)
190
+
191
+ print(f"Weights loaded.")
192
+ return sd_model
modules/sd_samplers.py ADDED
@@ -0,0 +1,380 @@
1
+ from collections import namedtuple
2
+ import numpy as np
3
+ import torch
4
+ import tqdm
5
+ from PIL import Image
6
+ import inspect
7
+ import k_diffusion.sampling
8
+ import ldm.models.diffusion.ddim
9
+ import ldm.models.diffusion.plms
10
+ from modules import prompt_parser
11
+
12
+ from modules.shared import opts, cmd_opts, state
13
+ import modules.shared as shared
14
+
15
+
16
+ SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
17
+
18
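+ # (UI label, function name in k_diffusion.sampling, aliases, per-sampler options); entries whose sampling function is missing from the installed k-diffusion are filtered out below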
+ samplers_k_diffusion = [
19
+ ('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}),
20
+ ('Euler', 'sample_euler', ['k_euler'], {}),
21
+ ('LMS', 'sample_lms', ['k_lms'], {}),
22
+ ('Heun', 'sample_heun', ['k_heun'], {}),
23
+ ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
24
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
25
+ ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
26
+ ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
27
+ ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
28
+ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
29
+ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
30
+ ]
31
+
32
+ samplers_data_k_diffusion = [
33
+ SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
34
+ for label, funcname, aliases, options in samplers_k_diffusion
35
+ if hasattr(k_diffusion.sampling, funcname)
36
+ ]
37
+
38
+ all_samplers = [
39
+ *samplers_data_k_diffusion,
40
+ SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
41
+ SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
42
+ ]
43
+
44
+ samplers = []
45
+ samplers_for_img2img = []
46
+
47
+
48
+ def create_sampler_with_index(list_of_configs, index, model):
49
+ config = list_of_configs[index]
50
+ sampler = config.constructor(model)
51
+ sampler.config = config
52
+
53
+ return sampler
54
+
55
+
56
+ def set_samplers():
57
+ global samplers, samplers_for_img2img
58
+
59
+ hidden = set(opts.hide_samplers)
60
+ hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive'])
61
+
62
+ samplers = [x for x in all_samplers if x.name not in hidden]
63
+ samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
64
+
65
+
66
+ set_samplers()
67
+
68
+ sampler_extra_params = {
69
+ 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
70
+ 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
71
+ 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
72
+ }
73
+
74
+ def setup_img2img_steps(p, steps=None):
75
+ if opts.img2img_fix_steps or steps is not None:
76
+ steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
77
+ t_enc = p.steps - 1
78
+ else:
79
+ steps = p.steps
80
+ t_enc = int(min(p.denoising_strength, 0.999) * steps)
81
+
82
+ return steps, t_enc
83
+
84
+
85
+ def sample_to_image(samples):
86
+ x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0]
87
+ x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
88
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
89
+ x_sample = x_sample.astype(np.uint8)
90
+ return Image.fromarray(x_sample)
91
+
92
+
93
+ def store_latent(decoded):
94
+ state.current_latent = decoded
95
+
96
+ if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
97
+ if not shared.parallel_processing_allowed:
98
+ shared.state.current_image = sample_to_image(decoded)
99
+
100
+
101
+
102
+ def extended_tdqm(sequence, *args, desc=None, **kwargs):
103
+ state.sampling_steps = len(sequence)
104
+ state.sampling_step = 0
105
+
106
+ seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
107
+
108
+ for x in seq:
109
+ if state.interrupted:
110
+ break
111
+
112
+ yield x
113
+
114
+ state.sampling_step += 1
115
+ shared.total_tqdm.update()
116
+
117
+
118
+ ldm.models.diffusion.ddim.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
119
+ ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
120
+
121
+
122
+ class VanillaStableDiffusionSampler:
123
+ def __init__(self, constructor, sd_model):
124
+ self.sampler = constructor(sd_model)
125
+ self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
126
+ self.mask = None
127
+ self.nmask = None
128
+ self.init_latent = None
129
+ self.sampler_noises = None
130
+ self.step = 0
131
+ self.eta = None
132
+ self.default_eta = 0.0
133
+ self.config = None
134
+
135
+ def number_of_needed_noises(self, p):
136
+ return 0
137
+
138
+ def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
139
+ conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
140
+ unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
141
+
142
+ assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
143
+ cond = tensor
144
+
145
+ if self.mask is not None:
146
+ img_orig = self.sampler.model.q_sample(self.init_latent, ts)
147
+ x_dec = img_orig * self.mask + self.nmask * x_dec
148
+
149
+ res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
150
+
151
+ if self.mask is not None:
152
+ store_latent(self.init_latent * self.mask + self.nmask * res[1])
153
+ else:
154
+ store_latent(res[1])
155
+
156
+ self.step += 1
157
+ return res
158
+
159
+ def initialize(self, p):
160
+ self.eta = p.eta if p.eta is not None else opts.eta_ddim
161
+
162
+ for fieldname in ['p_sample_ddim', 'p_sample_plms']:
163
+ if hasattr(self.sampler, fieldname):
164
+ setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
165
+
166
+ self.mask = p.mask if hasattr(p, 'mask') else None
167
+ self.nmask = p.nmask if hasattr(p, 'nmask') else None
168
+
169
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
170
+ steps, t_enc = setup_img2img_steps(p, steps)
171
+
172
+ self.initialize(p)
173
+
174
+ # existing code fails with certain step counts, like 9
175
+ try:
176
+ self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
177
+ except Exception:
178
+ self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
179
+
180
+ x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
181
+
182
+ self.init_latent = x
183
+ self.step = 0
184
+
185
+ samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)
186
+
187
+ return samples
188
+
189
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
190
+ self.initialize(p)
191
+
192
+ self.init_latent = None
193
+ self.step = 0
194
+
195
+ steps = steps or p.steps
196
+
197
+ # existing code fails with certain step counts, like 9
198
+ try:
199
+ samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
200
+ except Exception:
201
+ samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
202
+
203
+ return samples_ddim
204
+
205
+
206
+ class CFGDenoiser(torch.nn.Module):
207
+ def __init__(self, model):
208
+ super().__init__()
209
+ self.inner_model = model
210
+ self.mask = None
211
+ self.nmask = None
212
+ self.init_latent = None
213
+ self.step = 0
214
+
215
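+ # classifier-free guidance: conditional and unconditional branches are evaluated together (or in batch-sized chunks when batch_cond_uncond is off), then denoised = uncond + sum(weight * cond_scale * (cond_i - uncond)) per image, which also covers AND-composited prompts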
+ def forward(self, x, sigma, uncond, cond, cond_scale):
216
+ conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
217
+ uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
218
+
219
+ batch_size = len(conds_list)
220
+ repeats = [len(conds_list[i]) for i in range(batch_size)]
221
+
222
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
223
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
224
+ cond_in = torch.cat([tensor, uncond])
225
+
226
+ if shared.batch_cond_uncond:
227
+ x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
228
+ else:
229
+ x_out = torch.zeros_like(x_in)
230
+ for batch_offset in range(0, x_out.shape[0], batch_size):
231
+ a = batch_offset
232
+ b = a + batch_size
233
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
234
+
235
+ denoised_uncond = x_out[-batch_size:]
236
+ denoised = torch.clone(denoised_uncond)
237
+
238
+ for i, conds in enumerate(conds_list):
239
+ for cond_index, weight in conds:
240
+ denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
241
+
242
+ if self.mask is not None:
243
+ denoised = self.init_latent * self.mask + self.nmask * denoised
244
+
245
+ self.step += 1
246
+
247
+ return denoised
248
+
249
+
250
+ def extended_trange(sampler, count, *args, **kwargs):
251
+ state.sampling_steps = count
252
+ state.sampling_step = 0
253
+
254
+ seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
255
+
256
+ for x in seq:
257
+ if state.interrupted:
258
+ break
259
+
260
+ if sampler.stop_at is not None and x > sampler.stop_at:
261
+ break
262
+
263
+ yield x
264
+
265
+ state.sampling_step += 1
266
+ shared.total_tqdm.update()
267
+
268
+
269
+ class TorchHijack:
270
+ def __init__(self, kdiff_sampler):
271
+ self.kdiff_sampler = kdiff_sampler
272
+
273
+ def __getattr__(self, item):
274
+ if item == 'randn_like':
275
+ return self.kdiff_sampler.randn_like
276
+
277
+ if hasattr(torch, item):
278
+ return getattr(torch, item)
279
+
280
+ raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
281
+
282
+
283
+ class KDiffusionSampler:
284
+ def __init__(self, funcname, sd_model):
285
+ self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
286
+ self.funcname = funcname
287
+ self.func = getattr(k_diffusion.sampling, self.funcname)
288
+ self.extra_params = sampler_extra_params.get(funcname, [])
289
+ self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
290
+ self.sampler_noises = None
291
+ self.sampler_noise_index = 0
292
+ self.stop_at = None
293
+ self.eta = None
294
+ self.default_eta = 1.0
295
+ self.config = None
296
+
297
+ def callback_state(self, d):
298
+ store_latent(d["denoised"])
299
+
300
+ def number_of_needed_noises(self, p):
301
+ return p.steps
302
+
303
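+ # hands out precomputed per-step noise when available (used to make a batch reproduce the same images as single-image generation); falls back to torch.randn_like when the shapes don't match or the list is exhausted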
+ def randn_like(self, x):
304
+ noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
305
+
306
+ if noise is not None and x.shape == noise.shape:
307
+ res = noise
308
+ else:
309
+ res = torch.randn_like(x)
310
+
311
+ self.sampler_noise_index += 1
312
+ return res
313
+
314
+ def initialize(self, p):
315
+ self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
316
+ self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
317
+ self.model_wrap.step = 0
318
+ self.sampler_noise_index = 0
319
+ self.eta = p.eta or opts.eta_ancestral
320
+
321
+ if hasattr(k_diffusion.sampling, 'trange'):
322
+ k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
323
+
324
+ if self.sampler_noises is not None:
325
+ k_diffusion.sampling.torch = TorchHijack(self)
326
+
327
+ extra_params_kwargs = {}
328
+ for param_name in self.extra_params:
329
+ if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
330
+ extra_params_kwargs[param_name] = getattr(p, param_name)
331
+
332
+ if 'eta' in inspect.signature(self.func).parameters:
333
+ extra_params_kwargs['eta'] = self.eta
334
+
335
+ return extra_params_kwargs
336
+
337
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
338
+ steps, t_enc = setup_img2img_steps(p, steps)
339
+
340
+ if p.sampler_noise_scheduler_override:
341
+ sigmas = p.sampler_noise_scheduler_override(steps)
342
+ elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
343
+ sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
344
+ else:
345
+ sigmas = self.model_wrap.get_sigmas(steps)
346
+
347
+ noise = noise * sigmas[steps - t_enc - 1]
348
+ xi = x + noise
349
+
350
+ extra_params_kwargs = self.initialize(p)
351
+
352
+ sigma_sched = sigmas[steps - t_enc - 1:]
353
+
354
+ self.model_wrap_cfg.init_latent = x
355
+
356
+ return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
357
+
358
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
359
+ steps = steps or p.steps
360
+
361
+ if p.sampler_noise_scheduler_override:
362
+ sigmas = p.sampler_noise_scheduler_override(steps)
363
+ elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
364
+ sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
365
+ else:
366
+ sigmas = self.model_wrap.get_sigmas(steps)
367
+
368
+ x = x * sigmas[0]
369
+
370
+ extra_params_kwargs = self.initialize(p)
371
+ if 'sigma_min' in inspect.signature(self.func).parameters:
372
+ extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
373
+ extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
374
+ if 'n' in inspect.signature(self.func).parameters:
375
+ extra_params_kwargs['n'] = steps
376
+ else:
377
+ extra_params_kwargs['sigmas'] = sigmas
378
+ samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
379
+ return samples
380
+
modules/shared.py ADDED
@@ -0,0 +1,366 @@
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import sys
6
+
7
+ import gradio as gr
8
+ import tqdm
9
+
10
+ import modules.artists
11
+ import modules.interrogate
12
+ import modules.memmon
13
+ import modules.sd_models
14
+ import modules.styles
15
+ import modules.devices as devices
16
+ from modules import sd_samplers, hypernetwork
17
+ from modules.paths import models_path, script_path, sd_path
18
+
19
+ sd_model_file = os.path.join(script_path, 'model.ckpt')
20
+ default_sd_model_file = sd_model_file
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
23
+ parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
24
+ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
25
+ parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
26
+ parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
27
+ parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
28
+ parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
29
+ parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
30
+ parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
31
+ parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
32
+ parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
33
+ parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
34
+ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
35
+ parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
36
+ parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
37
+ parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
38
+ parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
39
+ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
40
+ parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
41
+ parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
42
+ parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
43
+ parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
44
+ parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
45
+ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
46
+ parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
47
+ parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
48
+ parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
49
+ parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
50
+ parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
51
+ parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
52
+ parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
53
+ parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
54
+ parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
55
+ parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
56
+ parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
57
+ parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
58
+ parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
59
+ parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
60
+ parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
61
+ parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
62
+ parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
63
+ parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
64
+ parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
65
+
66
+
67
+ cmd_opts = parser.parse_args()
68
+
69
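+ # every module named in --use-cpu is pinned to the CPU; everything else gets the best available torch device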
+ devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
70
+ (devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
71
+
72
+ device = devices.device
73
+
74
+ batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
75
+ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
76
+
77
+ config_filename = cmd_opts.ui_settings_file
78
+
79
+ hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks'))
80
+
81
+
82
+ def selected_hypernetwork():
83
+ return hypernetworks.get(opts.sd_hypernetwork, None)
84
+
85
+
86
+ class State:
87
+ interrupted = False
88
+ job = ""
89
+ job_no = 0
90
+ job_count = 0
91
+ job_timestamp = '0'
92
+ sampling_step = 0
93
+ sampling_steps = 0
94
+ current_latent = None
95
+ current_image = None
96
+ current_image_sampling_step = 0
97
+ textinfo = None
98
+
99
+ def interrupt(self):
100
+ self.interrupted = True
101
+
102
+ def nextjob(self):
103
+ self.job_no += 1
104
+ self.sampling_step = 0
105
+ self.current_image_sampling_step = 0
106
+
107
+ def get_job_timestamp(self):
108
+ return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
109
+
110
+
111
+ state = State()
112
+
113
+ artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
114
+
115
+ styles_filename = cmd_opts.styles_file
116
+ prompt_styles = modules.styles.StyleDatabase(styles_filename)
117
+
118
+ interrogator = modules.interrogate.InterrogateModels("interrogate")
119
+
120
+ face_restorers = []
121
+ # This was moved to webui.py with the other model "setup" calls.
122
+ # modules.sd_models.list_models()
123
+
124
+
125
+ def realesrgan_models_names():
126
+ import modules.realesrgan_model
127
+ return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
128
+
129
+
130
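+ # one settings entry: default value, UI label, gradio component class and its arguments, and an optional on-change callback; .section is filled in by options_section()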
+ class OptionInfo:
131
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None):
132
+ self.default = default
133
+ self.label = label
134
+ self.component = component
135
+ self.component_args = component_args
136
+ self.onchange = onchange
137
+ self.section = None
138
+
139
+
140
+ def options_section(section_identifer, options_dict):
141
+ for k, v in options_dict.items():
142
+ v.section = section_identifer
143
+
144
+ return options_dict
145
+
146
+
147
+ hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
148
+
149
+ options_templates = {}
150
+
151
+ options_templates.update(options_section(('saving-images', "Saving images/grids"), {
152
+ "samples_save": OptionInfo(True, "Always save all generated images"),
153
+ "samples_format": OptionInfo('png', 'File format for images'),
154
+ "samples_filename_pattern": OptionInfo("", "Images filename pattern"),
155
+
156
+ "grid_save": OptionInfo(True, "Always save all generated image grids"),
157
+ "grid_format": OptionInfo('png', 'File format for grids'),
158
+ "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
159
+ "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
160
+ "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
161
+
162
+ "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
163
+ "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
164
+ "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
165
+ "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
166
+ "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
167
+
168
+ "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
169
+ "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
170
+ }))
171
+
172
+ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
173
+ "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
174
+ "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
175
+ "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
176
+ "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
177
+ "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
178
+ "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
179
+ "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
180
+ "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
181
+ }))
182
+
183
+ options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
184
+ "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
185
+ "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
186
+ "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
187
+ "directories_filename_pattern": OptionInfo("", "Directory name pattern"),
188
+ "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
189
+ }))
190
+
191
+ options_templates.update(options_section(('upscaling', "Upscaling"), {
192
+ "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
193
+ "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
194
+ "realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
195
+ "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
196
+ "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
197
+ "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
198
+ "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
199
+ }))
200
+
201
+ options_templates.update(options_section(('face-restoration', "Face restoration"), {
202
+ "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
203
+ "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
204
+ "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
205
+ }))
206
+
207
+ options_templates.update(options_section(('system', "System"), {
208
+ "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
209
+ "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
210
+ "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
211
+ }))
212
+
213
+ options_templates.update(options_section(('sd', "Stable Diffusion"), {
214
+ "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
215
+ "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
216
+ "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
217
+ "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
218
+ "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
219
+ "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
220
+ "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
221
+ "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
222
+ "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
223
+ "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
224
+ "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
225
+ }))
226
+
227
+ options_templates.update(options_section(('interrogate', "Interrogate Options"), {
228
+ "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
229
+ "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
230
+ "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
231
+ "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
232
+ "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
233
+ "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
234
+ }))
235
+
236
+ options_templates.update(options_section(('ui', "User interface"), {
237
+ "show_progressbar": OptionInfo(True, "Show progressbar"),
238
+ "show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
239
+ "return_grid": OptionInfo(True, "Show grid in results for web"),
240
+ "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
241
+ "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
242
+ "font": OptionInfo("", "Font for image grids that have text"),
243
+ "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
244
+ "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
245
+ "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
246
+ }))
247
+
248
+ options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
249
+ "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
250
+ "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
251
+ "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
252
+ "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
253
+ 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
254
+ 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
255
+ 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
256
+ }))
257
+
258
+
259
+ class Options:
260
+ data = None
261
+ data_labels = options_templates
262
+ typemap = {int: float}
263
+
264
+ def __init__(self):
265
+ self.data = {k: v.default for k, v in self.data_labels.items()}
266
+
267
+ def __setattr__(self, key, value):
268
+ if self.data is not None:
269
+ if key in self.data:
270
+ self.data[key] = value
271
+
272
+ return super(Options, self).__setattr__(key, value)
273
+
274
+ def __getattr__(self, item):
275
+ if self.data is not None:
276
+ if item in self.data:
277
+ return self.data[item]
278
+
279
+ if item in self.data_labels:
280
+ return self.data_labels[item].default
281
+
282
+ return super(Options, self).__getattribute__(item)
283
+
284
+ def save(self, filename):
285
+ with open(filename, "w", encoding="utf8") as file:
286
+ json.dump(self.data, file)
287
+
288
+ def same_type(self, x, y):
289
+ if x is None or y is None:
290
+ return True
291
+
292
+ type_x = self.typemap.get(type(x), type(x))
293
+ type_y = self.typemap.get(type(y), type(y))
294
+
295
+ return type_x == type_y
296
+
297
+ def load(self, filename):
298
+ with open(filename, "r", encoding="utf8") as file:
299
+ self.data = json.load(file)
300
+
301
+ bad_settings = 0
302
+ for k, v in self.data.items():
303
+ info = self.data_labels.get(k, None)
304
+ if info is not None and not self.same_type(info.default, v):
305
+ print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
306
+ bad_settings += 1
307
+
308
+ if bad_settings > 0:
309
+ print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
310
+
311
+ def onchange(self, key, func):
312
+ item = self.data_labels.get(key)
313
+ item.onchange = func
314
+
315
+ def dumpjson(self):
316
+ d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
317
+ return json.dumps(d)
318
+
319
+
320
+ opts = Options()
321
+ if os.path.exists(config_filename):
322
+ opts.load(config_filename)
323
+
324
+ sd_upscalers = []
325
+
326
+ sd_model = None
327
+
328
+ progress_print_out = sys.stdout
329
+
330
+
331
+ class TotalTQDM:
332
+ def __init__(self):
333
+ self._tqdm = None
334
+
335
+ def reset(self):
336
+ self._tqdm = tqdm.tqdm(
337
+ desc="Total progress",
338
+ total=state.job_count * state.sampling_steps,
339
+ position=1,
340
+ file=progress_print_out
341
+ )
342
+
343
+ def update(self):
344
+ if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
345
+ return
346
+ if self._tqdm is None:
347
+ self.reset()
348
+ self._tqdm.update()
349
+
350
+ def updateTotal(self, new_total):
351
+ if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
352
+ return
353
+ if self._tqdm is None:
354
+ self.reset()
355
+ self._tqdm.total=new_total
356
+
357
+ def clear(self):
358
+ if self._tqdm is not None:
359
+ self._tqdm.close()
360
+ self._tqdm = None
361
+
362
+
363
+ total_tqdm = TotalTQDM()
364
+
365
+ mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
366
+ mem_mon.start()
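A brief usage note on the Options machinery above, since the other modules in this commit read their settings through it. The sketch below is hedged: it assumes the repository root is on the import path so the module is importable as `modules.shared`, and the filename used is hypothetical; setting names are taken from the option templates above.

```python
# Hedged usage sketch, not part of the commit.
from modules import shared

# Reads go through __getattr__ and fall back to the OptionInfo default
# when the setting has never been saved.
print(shared.opts.jpeg_quality)          # 80 unless a saved config overrides it

# Writes to a known setting key are routed into opts.data by __setattr__.
shared.opts.jpeg_quality = 95
shared.opts.save("config_backup.json")   # hypothetical filename

# onchange() only records the callback on the setting's OptionInfo entry;
# nothing in this module calls it automatically.
shared.opts.onchange("sd_model_checkpoint", lambda: print("checkpoint changed"))

# dumpjson() serialises current values (defaults filled in) for the UI.
print(shared.opts.dumpjson()[:80])
```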
modules/styles.py ADDED
@@ -0,0 +1,92 @@
1
+ # We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime
2
+ from __future__ import annotations
3
+
4
+ import csv
5
+ import os
6
+ import os.path
7
+ import typing
8
+ import collections.abc as abc
9
+ import tempfile
10
+ import shutil
11
+
12
+ if typing.TYPE_CHECKING:
13
+ # Only import this when code is being type-checked, it doesn't have any effect at runtime
14
+ from .processing import StableDiffusionProcessing
15
+
16
+
17
+ class PromptStyle(typing.NamedTuple):
18
+ name: str
19
+ prompt: str
20
+ negative_prompt: str
21
+
22
+
23
+ def merge_prompts(style_prompt: str, prompt: str) -> str:
24
+ if "{prompt}" in style_prompt:
25
+ res = style_prompt.replace("{prompt}", prompt)
26
+ else:
27
+ parts = filter(None, (prompt.strip(), style_prompt.strip()))
28
+ res = ", ".join(parts)
29
+
30
+ return res
31
+
32
+
33
+ def apply_styles_to_prompt(prompt, styles):
34
+ for style in styles:
35
+ prompt = merge_prompts(style, prompt)
36
+
37
+ return prompt
38
+
39
+
40
+ class StyleDatabase:
41
+ def __init__(self, path: str):
42
+ self.no_style = PromptStyle("None", "", "")
43
+ self.styles = {"None": self.no_style}
44
+
45
+ if not os.path.exists(path):
46
+ return
47
+
48
+ with open(path, "r", encoding="utf8", newline='') as file:
49
+ reader = csv.DictReader(file)
50
+ for row in reader:
51
+ # Support loading old CSV format with "name, text"-columns
52
+ prompt = row["prompt"] if "prompt" in row else row["text"]
53
+ negative_prompt = row.get("negative_prompt", "")
54
+ self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
55
+
56
+ def get_style_prompts(self, styles):
57
+ return [self.styles.get(x, self.no_style).prompt for x in styles]
58
+
59
+ def get_negative_style_prompts(self, styles):
60
+ return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
61
+
62
+ def apply_styles_to_prompt(self, prompt, styles):
63
+ return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
64
+
65
+ def apply_negative_styles_to_prompt(self, prompt, styles):
66
+ return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
67
+
68
+ def apply_styles(self, p: StableDiffusionProcessing) -> None:
69
+ if isinstance(p.prompt, list):
70
+ p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
71
+ else:
72
+ p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)
73
+
74
+ if isinstance(p.negative_prompt, list):
75
+ p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
76
+ else:
77
+ p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
78
+
79
+ def save_styles(self, path: str) -> None:
80
+ # Write to temporary file first, so we don't nuke the file if something goes wrong
81
+ fd, temp_path = tempfile.mkstemp(".csv")
82
+ with os.fdopen(fd, "w", encoding="utf8", newline='') as file:
83
+ # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
84
+ # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
85
+ writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
86
+ writer.writeheader()
87
+ writer.writerows(style._asdict() for k, style in self.styles.items())
88
+
89
+ # Always keep a backup file around
90
+ if os.path.exists(path):
91
+ shutil.move(path, path + ".bak")
92
+ shutil.move(temp_path, path)
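The `{prompt}` placeholder is the one non-obvious rule in the style logic above: a style either wraps the user prompt or is appended to it. The sketch below copies `merge_prompts` verbatim from modules/styles.py so it runs standalone; the example prompts are made up.

```python
# Standalone illustration of the style-merging rule.
def merge_prompts(style_prompt: str, prompt: str) -> str:
    if "{prompt}" in style_prompt:
        res = style_prompt.replace("{prompt}", prompt)
    else:
        parts = filter(None, (prompt.strip(), style_prompt.strip()))
        res = ", ".join(parts)
    return res

# A style containing "{prompt}" wraps the user prompt...
print(merge_prompts("masterpiece, {prompt}, 4k", "a red fox"))
# masterpiece, a red fox, 4k

# ...while a style without the placeholder is appended after a comma.
print(merge_prompts("oil painting, canvas texture", "a red fox"))
# a red fox, oil painting, canvas texture

# Empty pieces are dropped by filter(None, ...), so no stray commas appear.
print(merge_prompts("", "a red fox"))
# a red fox
```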
modules/swinir_model.py ADDED
@@ -0,0 +1,142 @@
1
+ import contextlib
2
+ import os
3
+
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from basicsr.utils.download_util import load_file_from_url
8
+ from tqdm import tqdm
9
+
10
+ from modules import modelloader
11
+ from modules.paths import models_path
12
+ from modules.shared import cmd_opts, opts, device
13
+ from modules.swinir_model_arch import SwinIR as net
14
+ from modules.upscaler import Upscaler, UpscalerData
15
+
16
+ precision_scope = (
17
+ torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
18
+ )
19
+
20
+
21
+ class UpscalerSwinIR(Upscaler):
22
+ def __init__(self, dirname):
23
+ self.name = "SwinIR"
24
+ self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
25
+ "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
26
+ "-L_x4_GAN.pth "
27
+ self.model_name = "SwinIR 4x"
28
+ self.model_path = os.path.join(models_path, self.name)
29
+ self.user_path = dirname
30
+ super().__init__()
31
+ scalers = []
32
+ model_files = self.find_models(ext_filter=[".pt", ".pth"])
33
+ for model in model_files:
34
+ if "http" in model:
35
+ name = self.model_name
36
+ else:
37
+ name = modelloader.friendly_name(model)
38
+ model_data = UpscalerData(name, model, self)
39
+ scalers.append(model_data)
40
+ self.scalers = scalers
41
+
42
+ def do_upscale(self, img, model_file):
43
+ model = self.load_model(model_file)
44
+ if model is None:
45
+ return img
46
+ model = model.to(device)
47
+ img = upscale(img, model)
48
+ try:
49
+ torch.cuda.empty_cache()
50
+ except:
51
+ pass
52
+ return img
53
+
54
+ def load_model(self, path, scale=4):
55
+ if "http" in path:
56
+ dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
57
+ filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
58
+ else:
59
+ filename = path
60
+ if filename is None or not os.path.exists(filename):
61
+ return None
62
+ model = net(
63
+ upscale=scale,
64
+ in_chans=3,
65
+ img_size=64,
66
+ window_size=8,
67
+ img_range=1.0,
68
+ depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
69
+ embed_dim=240,
70
+ num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
71
+ mlp_ratio=2,
72
+ upsampler="nearest+conv",
73
+ resi_connection="3conv",
74
+ )
75
+
76
+ pretrained_model = torch.load(filename)
77
+ model.load_state_dict(pretrained_model["params_ema"], strict=True)
78
+ if not cmd_opts.no_half:
79
+ model = model.half()
80
+ return model
81
+
82
+
83
+ def upscale(
84
+ img,
85
+ model,
86
+ tile=opts.SWIN_tile,
87
+ tile_overlap=opts.SWIN_tile_overlap,
88
+ window_size=8,
89
+ scale=4,
90
+ ):
91
+ img = np.array(img)
92
+ img = img[:, :, ::-1]
93
+ img = np.moveaxis(img, 2, 0) / 255
94
+ img = torch.from_numpy(img).float()
95
+ img = img.unsqueeze(0).to(device)
96
+ with torch.no_grad(), precision_scope("cuda"):
97
+ _, _, h_old, w_old = img.size()
98
+ h_pad = (h_old // window_size + 1) * window_size - h_old
99
+ w_pad = (w_old // window_size + 1) * window_size - w_old
100
+ img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
101
+ img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
102
+ output = inference(img, model, tile, tile_overlap, window_size, scale)
103
+ output = output[..., : h_old * scale, : w_old * scale]
104
+ output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
105
+ if output.ndim == 3:
106
+ output = np.transpose(
107
+ output[[2, 1, 0], :, :], (1, 2, 0)
108
+ ) # CHW-RGB to HWC-BGR
109
+ output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
110
+ return Image.fromarray(output, "RGB")
111
+
112
+
113
+ def inference(img, model, tile, tile_overlap, window_size, scale):
114
+ # test the image tile by tile
115
+ b, c, h, w = img.size()
116
+ tile = min(tile, h, w)
117
+ assert tile % window_size == 0, "tile size should be a multiple of window_size"
118
+ sf = scale
119
+
120
+ stride = tile - tile_overlap
121
+ h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
122
+ w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
123
+ E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
124
+ W = torch.zeros_like(E, dtype=torch.half, device=device)
125
+
126
+ with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
127
+ for h_idx in h_idx_list:
128
+ for w_idx in w_idx_list:
129
+ in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
130
+ out_patch = model(in_patch)
131
+ out_patch_mask = torch.ones_like(out_patch)
132
+
133
+ E[
134
+ ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
135
+ ].add_(out_patch)
136
+ W[
137
+ ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
138
+ ].add_(out_patch_mask)
139
+ pbar.update(1)
140
+ output = E.div_(W)
141
+
142
+ return output
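The tile loop in `inference()` accumulates overlapping model outputs in `E` and a per-pixel contribution count in `W`, then divides, so overlapping regions are averaged rather than stitched with visible seams. A toy 1-D version of the same bookkeeping (identity "model", made-up sizes) shows the idea; it is only an illustration, not the function above.

```python
# Toy 1-D illustration of the overlap-and-average bookkeeping in inference().
import torch

signal = torch.arange(16, dtype=torch.float32)   # pretend image row
tile, tile_overlap, sf = 8, 2, 1                 # sf=1: no upscaling in this toy
stride = tile - tile_overlap
idx_list = list(range(0, len(signal) - tile, stride)) + [len(signal) - tile]

E = torch.zeros(len(signal) * sf)   # accumulated outputs
W = torch.zeros_like(E)             # how many tiles touched each position

for i in idx_list:
    patch = signal[i:i + tile]
    out = patch                      # a real model would transform/upscale here
    E[i * sf:(i + tile) * sf] += out
    W[i * sf:(i + tile) * sf] += torch.ones_like(out)

recovered = E / W                    # overlapped regions are averaged
assert torch.allclose(recovered, signal)
print(W)                             # positions covered by two tiles show 2
```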
modules/swinir_model_arch.py ADDED
@@ -0,0 +1,867 @@
1
+ # -----------------------------------------------------------------------------------
2
+ # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
3
+ # Originally Written by Ze Liu, Modified by Jingyun Liang.
4
+ # -----------------------------------------------------------------------------------
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint as checkpoint
11
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
12
+
13
+
14
+ class Mlp(nn.Module):
15
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
16
+ super().__init__()
17
+ out_features = out_features or in_features
18
+ hidden_features = hidden_features or in_features
19
+ self.fc1 = nn.Linear(in_features, hidden_features)
20
+ self.act = act_layer()
21
+ self.fc2 = nn.Linear(hidden_features, out_features)
22
+ self.drop = nn.Dropout(drop)
23
+
24
+ def forward(self, x):
25
+ x = self.fc1(x)
26
+ x = self.act(x)
27
+ x = self.drop(x)
28
+ x = self.fc2(x)
29
+ x = self.drop(x)
30
+ return x
31
+
32
+
33
+ def window_partition(x, window_size):
34
+ """
35
+ Args:
36
+ x: (B, H, W, C)
37
+ window_size (int): window size
38
+
39
+ Returns:
40
+ windows: (num_windows*B, window_size, window_size, C)
41
+ """
42
+ B, H, W, C = x.shape
43
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
44
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
45
+ return windows
46
+
47
+
48
+ def window_reverse(windows, window_size, H, W):
49
+ """
50
+ Args:
51
+ windows: (num_windows*B, window_size, window_size, C)
52
+ window_size (int): Window size
53
+ H (int): Height of image
54
+ W (int): Width of image
55
+
56
+ Returns:
57
+ x: (B, H, W, C)
58
+ """
59
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
60
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
61
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
62
+ return x
63
+
64
+
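`window_partition` and `window_reverse` are exact inverses whenever H and W are multiples of the window size; a quick round-trip check (a sketch, assuming this file is importable as `modules.swinir_model_arch`):

```python
# Round-trip sanity check for window_partition / window_reverse.
import torch
from modules.swinir_model_arch import window_partition, window_reverse

x = torch.randn(2, 16, 16, 3)                 # B, H, W, C with H, W % window_size == 0
win = window_partition(x, window_size=8)      # -> (2 * 2 * 2, 8, 8, 3)
print(win.shape)
x_back = window_reverse(win, window_size=8, H=16, W=16)
assert torch.equal(x, x_back)                 # partition and reverse are exact inverses
```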
65
+ class WindowAttention(nn.Module):
66
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
67
+ It supports both of shifted and non-shifted window.
68
+
69
+ Args:
70
+ dim (int): Number of input channels.
71
+ window_size (tuple[int]): The height and width of the window.
72
+ num_heads (int): Number of attention heads.
73
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
74
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
75
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
76
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
77
+ """
78
+
79
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
80
+
81
+ super().__init__()
82
+ self.dim = dim
83
+ self.window_size = window_size # Wh, Ww
84
+ self.num_heads = num_heads
85
+ head_dim = dim // num_heads
86
+ self.scale = qk_scale or head_dim ** -0.5
87
+
88
+ # define a parameter table of relative position bias
89
+ self.relative_position_bias_table = nn.Parameter(
90
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
91
+
92
+ # get pair-wise relative position index for each token inside the window
93
+ coords_h = torch.arange(self.window_size[0])
94
+ coords_w = torch.arange(self.window_size[1])
95
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
96
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
97
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
98
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
99
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
100
+ relative_coords[:, :, 1] += self.window_size[1] - 1
101
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
102
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
103
+ self.register_buffer("relative_position_index", relative_position_index)
104
+
105
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
106
+ self.attn_drop = nn.Dropout(attn_drop)
107
+ self.proj = nn.Linear(dim, dim)
108
+
109
+ self.proj_drop = nn.Dropout(proj_drop)
110
+
111
+ trunc_normal_(self.relative_position_bias_table, std=.02)
112
+ self.softmax = nn.Softmax(dim=-1)
113
+
114
+ def forward(self, x, mask=None):
115
+ """
116
+ Args:
117
+ x: input features with shape of (num_windows*B, N, C)
118
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
119
+ """
120
+ B_, N, C = x.shape
121
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
122
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
123
+
124
+ q = q * self.scale
125
+ attn = (q @ k.transpose(-2, -1))
126
+
127
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
128
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
129
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
130
+ attn = attn + relative_position_bias.unsqueeze(0)
131
+
132
+ if mask is not None:
133
+ nW = mask.shape[0]
134
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
135
+ attn = attn.view(-1, self.num_heads, N, N)
136
+ attn = self.softmax(attn)
137
+ else:
138
+ attn = self.softmax(attn)
139
+
140
+ attn = self.attn_drop(attn)
141
+
142
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
143
+ x = self.proj(x)
144
+ x = self.proj_drop(x)
145
+ return x
146
+
147
+ def extra_repr(self) -> str:
148
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
149
+
150
+ def flops(self, N):
151
+ # calculate flops for 1 window with token length of N
152
+ flops = 0
153
+ # qkv = self.qkv(x)
154
+ flops += N * self.dim * 3 * self.dim
155
+ # attn = (q @ k.transpose(-2, -1))
156
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
157
+ # x = (attn @ v)
158
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
159
+ # x = self.proj(x)
160
+ flops += N * self.dim * self.dim
161
+ return flops
162
+
163
+
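The relative-position machinery in `WindowAttention.__init__` is easier to see on a tiny window. The snippet below repeats the same construction, standalone, for a 2x2 window, where the bias table has (2*2-1)*(2*2-1) = 9 entries and every token pair maps to the entry for its spatial offset:

```python
# Relative-position-index construction for a 2x2 window (4 tokens, 9 bias entries).
import torch

window_size = (2, 2)
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))           # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                            # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()      # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1                       # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)                    # Wh*Ww, Wh*Ww

print(relative_position_index)
# Token pairs with the same spatial offset (e.g. "one step right") share an index,
# which is what makes the learned bias translation-invariant inside the window.
print(relative_position_index.unique().numel(), "distinct offsets out of", 9)
```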
164
+ class SwinTransformerBlock(nn.Module):
165
+ r""" Swin Transformer Block.
166
+
167
+ Args:
168
+ dim (int): Number of input channels.
169
+ input_resolution (tuple[int]): Input resolution.
170
+ num_heads (int): Number of attention heads.
171
+ window_size (int): Window size.
172
+ shift_size (int): Shift size for SW-MSA.
173
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
174
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
175
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
176
+ drop (float, optional): Dropout rate. Default: 0.0
177
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
178
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
179
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
180
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
181
+ """
182
+
183
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
184
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
185
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
186
+ super().__init__()
187
+ self.dim = dim
188
+ self.input_resolution = input_resolution
189
+ self.num_heads = num_heads
190
+ self.window_size = window_size
191
+ self.shift_size = shift_size
192
+ self.mlp_ratio = mlp_ratio
193
+ if min(self.input_resolution) <= self.window_size:
194
+ # if window size is larger than input resolution, we don't partition windows
195
+ self.shift_size = 0
196
+ self.window_size = min(self.input_resolution)
197
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
198
+
199
+ self.norm1 = norm_layer(dim)
200
+ self.attn = WindowAttention(
201
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
202
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
203
+
204
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
205
+ self.norm2 = norm_layer(dim)
206
+ mlp_hidden_dim = int(dim * mlp_ratio)
207
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
208
+
209
+ if self.shift_size > 0:
210
+ attn_mask = self.calculate_mask(self.input_resolution)
211
+ else:
212
+ attn_mask = None
213
+
214
+ self.register_buffer("attn_mask", attn_mask)
215
+
216
+ def calculate_mask(self, x_size):
217
+ # calculate attention mask for SW-MSA
218
+ H, W = x_size
219
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
220
+ h_slices = (slice(0, -self.window_size),
221
+ slice(-self.window_size, -self.shift_size),
222
+ slice(-self.shift_size, None))
223
+ w_slices = (slice(0, -self.window_size),
224
+ slice(-self.window_size, -self.shift_size),
225
+ slice(-self.shift_size, None))
226
+ cnt = 0
227
+ for h in h_slices:
228
+ for w in w_slices:
229
+ img_mask[:, h, w, :] = cnt
230
+ cnt += 1
231
+
232
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
233
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
234
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
235
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
236
+
237
+ return attn_mask
238
+
239
+ def forward(self, x, x_size):
240
+ H, W = x_size
241
+ B, L, C = x.shape
242
+ # assert L == H * W, "input feature has wrong size"
243
+
244
+ shortcut = x
245
+ x = self.norm1(x)
246
+ x = x.view(B, H, W, C)
247
+
248
+ # cyclic shift
249
+ if self.shift_size > 0:
250
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
251
+ else:
252
+ shifted_x = x
253
+
254
+ # partition windows
255
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
256
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
257
+
258
+ # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are any multiple of the window size)
259
+ if self.input_resolution == x_size:
260
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
261
+ else:
262
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
263
+
264
+ # merge windows
265
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
266
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
267
+
268
+ # reverse cyclic shift
269
+ if self.shift_size > 0:
270
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
271
+ else:
272
+ x = shifted_x
273
+ x = x.view(B, H * W, C)
274
+
275
+ # FFN
276
+ x = shortcut + self.drop_path(x)
277
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
278
+
279
+ return x
280
+
281
+ def extra_repr(self) -> str:
282
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
283
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
284
+
285
+ def flops(self):
286
+ flops = 0
287
+ H, W = self.input_resolution
288
+ # norm1
289
+ flops += self.dim * H * W
290
+ # W-MSA/SW-MSA
291
+ nW = H * W / self.window_size / self.window_size
292
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
293
+ # mlp
294
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
295
+ # norm2
296
+ flops += self.dim * H * W
297
+ return flops
298
+
299
+
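`calculate_mask` is the subtle part of the shifted-window block: after the cyclic shift, some windows mix pixels that were not neighbours in the original feature map, and those token pairs receive a -100 bias so the softmax effectively ignores them. A compact standalone check for an 8x8 map with window_size=4 and shift_size=2, reusing `window_partition` from this module (import path assumed):

```python
# Standalone sketch of the SW-MSA attention mask built in calculate_mask above.
import torch
from modules.swinir_model_arch import window_partition

H = W = 8
window_size, shift_size = 4, 2

img_mask = torch.zeros((1, H, W, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, 0.0)

print(attn_mask.shape)   # (4 windows, 16, 16)
# Windows touching the wrap-around boundary mix several region labels, so some of
# their token pairs are masked; the top-left window stays fully unmasked.
print([(attn_mask[i] != 0).any().item() for i in range(attn_mask.shape[0])])
```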
300
+ class PatchMerging(nn.Module):
301
+ r""" Patch Merging Layer.
302
+
303
+ Args:
304
+ input_resolution (tuple[int]): Resolution of input feature.
305
+ dim (int): Number of input channels.
306
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
307
+ """
308
+
309
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
310
+ super().__init__()
311
+ self.input_resolution = input_resolution
312
+ self.dim = dim
313
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
314
+ self.norm = norm_layer(4 * dim)
315
+
316
+ def forward(self, x):
317
+ """
318
+ x: B, H*W, C
319
+ """
320
+ H, W = self.input_resolution
321
+ B, L, C = x.shape
322
+ assert L == H * W, "input feature has wrong size"
323
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
324
+
325
+ x = x.view(B, H, W, C)
326
+
327
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
328
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
329
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
330
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
331
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
332
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
333
+
334
+ x = self.norm(x)
335
+ x = self.reduction(x)
336
+
337
+ return x
338
+
339
+ def extra_repr(self) -> str:
340
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
341
+
342
+ def flops(self):
343
+ H, W = self.input_resolution
344
+ flops = H * W * self.dim
345
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
346
+ return flops
347
+
348
+
349
+ class BasicLayer(nn.Module):
350
+ """ A basic Swin Transformer layer for one stage.
351
+
352
+ Args:
353
+ dim (int): Number of input channels.
354
+ input_resolution (tuple[int]): Input resolution.
355
+ depth (int): Number of blocks.
356
+ num_heads (int): Number of attention heads.
357
+ window_size (int): Local window size.
358
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
359
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
360
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
361
+ drop (float, optional): Dropout rate. Default: 0.0
362
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
363
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
364
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
365
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
366
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
367
+ """
368
+
369
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
370
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
371
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
372
+
373
+ super().__init__()
374
+ self.dim = dim
375
+ self.input_resolution = input_resolution
376
+ self.depth = depth
377
+ self.use_checkpoint = use_checkpoint
378
+
379
+ # build blocks
380
+ self.blocks = nn.ModuleList([
381
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
382
+ num_heads=num_heads, window_size=window_size,
383
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
384
+ mlp_ratio=mlp_ratio,
385
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
386
+ drop=drop, attn_drop=attn_drop,
387
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
388
+ norm_layer=norm_layer)
389
+ for i in range(depth)])
390
+
391
+ # patch merging layer
392
+ if downsample is not None:
393
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
394
+ else:
395
+ self.downsample = None
396
+
397
+ def forward(self, x, x_size):
398
+ for blk in self.blocks:
399
+ if self.use_checkpoint:
400
+ x = checkpoint.checkpoint(blk, x, x_size)
401
+ else:
402
+ x = blk(x, x_size)
403
+ if self.downsample is not None:
404
+ x = self.downsample(x)
405
+ return x
406
+
407
+ def extra_repr(self) -> str:
408
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
409
+
410
+ def flops(self):
411
+ flops = 0
412
+ for blk in self.blocks:
413
+ flops += blk.flops()
414
+ if self.downsample is not None:
415
+ flops += self.downsample.flops()
416
+ return flops
417
+
418
+
419
+ class RSTB(nn.Module):
420
+ """Residual Swin Transformer Block (RSTB).
421
+
422
+ Args:
423
+ dim (int): Number of input channels.
424
+ input_resolution (tuple[int]): Input resolution.
425
+ depth (int): Number of blocks.
426
+ num_heads (int): Number of attention heads.
427
+ window_size (int): Local window size.
428
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
429
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
430
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
431
+ drop (float, optional): Dropout rate. Default: 0.0
432
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
433
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
434
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
435
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
436
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
437
+ img_size: Input image size.
438
+ patch_size: Patch size.
439
+ resi_connection: The convolutional block before residual connection.
440
+ """
441
+
442
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
443
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
444
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
445
+ img_size=224, patch_size=4, resi_connection='1conv'):
446
+ super(RSTB, self).__init__()
447
+
448
+ self.dim = dim
449
+ self.input_resolution = input_resolution
450
+
451
+ self.residual_group = BasicLayer(dim=dim,
452
+ input_resolution=input_resolution,
453
+ depth=depth,
454
+ num_heads=num_heads,
455
+ window_size=window_size,
456
+ mlp_ratio=mlp_ratio,
457
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
458
+ drop=drop, attn_drop=attn_drop,
459
+ drop_path=drop_path,
460
+ norm_layer=norm_layer,
461
+ downsample=downsample,
462
+ use_checkpoint=use_checkpoint)
463
+
464
+ if resi_connection == '1conv':
465
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
466
+ elif resi_connection == '3conv':
467
+ # to save parameters and memory
468
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
469
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
470
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
471
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
472
+
473
+ self.patch_embed = PatchEmbed(
474
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
475
+ norm_layer=None)
476
+
477
+ self.patch_unembed = PatchUnEmbed(
478
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
479
+ norm_layer=None)
480
+
481
+ def forward(self, x, x_size):
482
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
483
+
484
+ def flops(self):
485
+ flops = 0
486
+ flops += self.residual_group.flops()
487
+ H, W = self.input_resolution
488
+ flops += H * W * self.dim * self.dim * 9
489
+ flops += self.patch_embed.flops()
490
+ flops += self.patch_unembed.flops()
491
+
492
+ return flops
493
+
494
+
495
+ class PatchEmbed(nn.Module):
496
+ r""" Image to Patch Embedding
497
+
498
+ Args:
499
+ img_size (int): Image size. Default: 224.
500
+ patch_size (int): Patch token size. Default: 4.
501
+ in_chans (int): Number of input image channels. Default: 3.
502
+ embed_dim (int): Number of linear projection output channels. Default: 96.
503
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
504
+ """
505
+
506
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
507
+ super().__init__()
508
+ img_size = to_2tuple(img_size)
509
+ patch_size = to_2tuple(patch_size)
510
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
511
+ self.img_size = img_size
512
+ self.patch_size = patch_size
513
+ self.patches_resolution = patches_resolution
514
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
515
+
516
+ self.in_chans = in_chans
517
+ self.embed_dim = embed_dim
518
+
519
+ if norm_layer is not None:
520
+ self.norm = norm_layer(embed_dim)
521
+ else:
522
+ self.norm = None
523
+
524
+ def forward(self, x):
525
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
526
+ if self.norm is not None:
527
+ x = self.norm(x)
528
+ return x
529
+
530
+ def flops(self):
531
+ flops = 0
532
+ H, W = self.img_size
533
+ if self.norm is not None:
534
+ flops += H * W * self.embed_dim
535
+ return flops
536
+
537
+
538
+ class PatchUnEmbed(nn.Module):
539
+ r""" Image to Patch Unembedding
540
+
541
+ Args:
542
+ img_size (int): Image size. Default: 224.
543
+ patch_size (int): Patch token size. Default: 4.
544
+ in_chans (int): Number of input image channels. Default: 3.
545
+ embed_dim (int): Number of linear projection output channels. Default: 96.
546
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
547
+ """
548
+
549
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
550
+ super().__init__()
551
+ img_size = to_2tuple(img_size)
552
+ patch_size = to_2tuple(patch_size)
553
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
554
+ self.img_size = img_size
555
+ self.patch_size = patch_size
556
+ self.patches_resolution = patches_resolution
557
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
558
+
559
+ self.in_chans = in_chans
560
+ self.embed_dim = embed_dim
561
+
562
+ def forward(self, x, x_size):
563
+ B, HW, C = x.shape
564
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
565
+ return x
566
+
567
+ def flops(self):
568
+ flops = 0
569
+ return flops
570
+
571
+
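With patch_size=1, as SwinIR uses them, `PatchEmbed` and `PatchUnEmbed` are just reshapes between (B, C, H, W) feature maps and (B, H*W, C) token sequences, so the round trip is exact. A small shape check (import path assumed):

```python
# Shape round-trip through PatchEmbed / PatchUnEmbed with patch_size=1.
import torch
from modules.swinir_model_arch import PatchEmbed, PatchUnEmbed

embed_dim = 60
pe = PatchEmbed(img_size=64, patch_size=1, in_chans=embed_dim, embed_dim=embed_dim)
pu = PatchUnEmbed(img_size=64, patch_size=1, in_chans=embed_dim, embed_dim=embed_dim)

feat = torch.randn(1, embed_dim, 48, 40)      # B, C, H, W (any H, W works here)
tokens = pe(feat)                             # -> (1, 48*40, 60): flatten(2).transpose(1, 2)
print(tokens.shape)
restored = pu(tokens, x_size=(48, 40))        # -> (1, 60, 48, 40)
assert torch.equal(feat, restored)            # pure reshape, so the round trip is exact
```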
572
+ class Upsample(nn.Sequential):
573
+ """Upsample module.
574
+
575
+ Args:
576
+ scale (int): Scale factor. Supported scales: 2^n and 3.
577
+ num_feat (int): Channel number of intermediate features.
578
+ """
579
+
580
+ def __init__(self, scale, num_feat):
581
+ m = []
582
+ if (scale & (scale - 1)) == 0: # scale = 2^n
583
+ for _ in range(int(math.log(scale, 2))):
584
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
585
+ m.append(nn.PixelShuffle(2))
586
+ elif scale == 3:
587
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
588
+ m.append(nn.PixelShuffle(3))
589
+ else:
590
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
591
+ super(Upsample, self).__init__(*m)
592
+
593
+
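`Upsample` builds a 4x factor as two conv + PixelShuffle(2) stages, each converting 4*num_feat channels into a 2x larger spatial grid. A quick shape check (import path assumed):

```python
# Shape check for the Upsample module above at scale 4.
import torch
from modules.swinir_model_arch import Upsample

up = Upsample(scale=4, num_feat=64)
x = torch.randn(1, 64, 16, 16)
y = up(x)
print(y.shape)   # torch.Size([1, 64, 64, 64]): channels preserved, H and W x4
print(up)        # Conv2d(64, 256) followed by PixelShuffle(2), repeated twice
```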
594
+ class UpsampleOneStep(nn.Sequential):
595
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
596
+ Used in lightweight SR to save parameters.
597
+
598
+ Args:
599
+ scale (int): Scale factor. Supported scales: 2^n and 3.
600
+ num_feat (int): Channel number of intermediate features.
601
+
602
+ """
603
+
604
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
605
+ self.num_feat = num_feat
606
+ self.input_resolution = input_resolution
607
+ m = []
608
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
609
+ m.append(nn.PixelShuffle(scale))
610
+ super(UpsampleOneStep, self).__init__(*m)
611
+
612
+ def flops(self):
613
+ H, W = self.input_resolution
614
+ flops = H * W * self.num_feat * 3 * 9
615
+ return flops
616
+
617
+
618
+ class SwinIR(nn.Module):
619
+ r""" SwinIR
620
+ A PyTorch impl of: `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
621
+
622
+ Args:
623
+ img_size (int | tuple(int)): Input image size. Default 64
624
+ patch_size (int | tuple(int)): Patch size. Default: 1
625
+ in_chans (int): Number of input image channels. Default: 3
626
+ embed_dim (int): Patch embedding dimension. Default: 96
627
+ depths (tuple(int)): Depth of each Swin Transformer layer.
628
+ num_heads (tuple(int)): Number of attention heads in different layers.
629
+ window_size (int): Window size. Default: 7
630
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
631
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
632
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
633
+ drop_rate (float): Dropout rate. Default: 0
634
+ attn_drop_rate (float): Attention dropout rate. Default: 0
635
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
636
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
637
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
638
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
639
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
640
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
641
+ img_range: Image range. 1. or 255.
642
+ upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
643
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
644
+ """
645
+
646
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
647
+ embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
648
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
649
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
650
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
651
+ use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
652
+ **kwargs):
653
+ super(SwinIR, self).__init__()
654
+ num_in_ch = in_chans
655
+ num_out_ch = in_chans
656
+ num_feat = 64
657
+ self.img_range = img_range
658
+ if in_chans == 3:
659
+ rgb_mean = (0.4488, 0.4371, 0.4040)
660
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
661
+ else:
662
+ self.mean = torch.zeros(1, 1, 1, 1)
663
+ self.upscale = upscale
664
+ self.upsampler = upsampler
665
+ self.window_size = window_size
666
+
667
+ #####################################################################################################
668
+ ################################### 1, shallow feature extraction ###################################
669
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
670
+
671
+ #####################################################################################################
672
+ ################################### 2, deep feature extraction ######################################
673
+ self.num_layers = len(depths)
674
+ self.embed_dim = embed_dim
675
+ self.ape = ape
676
+ self.patch_norm = patch_norm
677
+ self.num_features = embed_dim
678
+ self.mlp_ratio = mlp_ratio
679
+
680
+ # split image into non-overlapping patches
681
+ self.patch_embed = PatchEmbed(
682
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
683
+ norm_layer=norm_layer if self.patch_norm else None)
684
+ num_patches = self.patch_embed.num_patches
685
+ patches_resolution = self.patch_embed.patches_resolution
686
+ self.patches_resolution = patches_resolution
687
+
688
+ # merge non-overlapping patches into image
689
+ self.patch_unembed = PatchUnEmbed(
690
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
691
+ norm_layer=norm_layer if self.patch_norm else None)
692
+
693
+ # absolute position embedding
694
+ if self.ape:
695
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
696
+ trunc_normal_(self.absolute_pos_embed, std=.02)
697
+
698
+ self.pos_drop = nn.Dropout(p=drop_rate)
699
+
700
+ # stochastic depth
701
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
702
+
703
+ # build Residual Swin Transformer blocks (RSTB)
704
+ self.layers = nn.ModuleList()
705
+ for i_layer in range(self.num_layers):
706
+ layer = RSTB(dim=embed_dim,
707
+ input_resolution=(patches_resolution[0],
708
+ patches_resolution[1]),
709
+ depth=depths[i_layer],
710
+ num_heads=num_heads[i_layer],
711
+ window_size=window_size,
712
+ mlp_ratio=self.mlp_ratio,
713
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
714
+ drop=drop_rate, attn_drop=attn_drop_rate,
715
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
716
+ norm_layer=norm_layer,
717
+ downsample=None,
718
+ use_checkpoint=use_checkpoint,
719
+ img_size=img_size,
720
+ patch_size=patch_size,
721
+ resi_connection=resi_connection
722
+
723
+ )
724
+ self.layers.append(layer)
725
+ self.norm = norm_layer(self.num_features)
726
+
727
+ # build the last conv layer in deep feature extraction
728
+ if resi_connection == '1conv':
729
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
730
+ elif resi_connection == '3conv':
731
+ # to save parameters and memory
732
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
733
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
734
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
735
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
736
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
737
+
738
+ #####################################################################################################
739
+ ################################ 3, high quality image reconstruction ################################
740
+ if self.upsampler == 'pixelshuffle':
741
+ # for classical SR
742
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
743
+ nn.LeakyReLU(inplace=True))
744
+ self.upsample = Upsample(upscale, num_feat)
745
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
746
+ elif self.upsampler == 'pixelshuffledirect':
747
+ # for lightweight SR (to save parameters)
748
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
749
+ (patches_resolution[0], patches_resolution[1]))
750
+ elif self.upsampler == 'nearest+conv':
751
+ # for real-world SR (less artifacts)
752
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
753
+ nn.LeakyReLU(inplace=True))
754
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
755
+ if self.upscale == 4:
756
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
757
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
758
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
759
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
760
+ else:
761
+ # for image denoising and JPEG compression artifact reduction
762
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
763
+
764
+ self.apply(self._init_weights)
765
+
766
+ def _init_weights(self, m):
767
+ if isinstance(m, nn.Linear):
768
+ trunc_normal_(m.weight, std=.02)
769
+ if isinstance(m, nn.Linear) and m.bias is not None:
770
+ nn.init.constant_(m.bias, 0)
771
+ elif isinstance(m, nn.LayerNorm):
772
+ nn.init.constant_(m.bias, 0)
773
+ nn.init.constant_(m.weight, 1.0)
774
+
775
+ @torch.jit.ignore
776
+ def no_weight_decay(self):
777
+ return {'absolute_pos_embed'}
778
+
779
+ @torch.jit.ignore
780
+ def no_weight_decay_keywords(self):
781
+ return {'relative_position_bias_table'}
782
+
783
+ def check_image_size(self, x):
784
+ _, _, h, w = x.size()
785
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
786
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
787
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
788
+ return x
789
+
790
+ def forward_features(self, x):
791
+ x_size = (x.shape[2], x.shape[3])
792
+ x = self.patch_embed(x)
793
+ if self.ape:
794
+ x = x + self.absolute_pos_embed
795
+ x = self.pos_drop(x)
796
+
797
+ for layer in self.layers:
798
+ x = layer(x, x_size)
799
+
800
+ x = self.norm(x) # B L C
801
+ x = self.patch_unembed(x, x_size)
802
+
803
+ return x
804
+
805
+ def forward(self, x):
806
+ H, W = x.shape[2:]
807
+ x = self.check_image_size(x)
808
+
809
+ self.mean = self.mean.type_as(x)
810
+ x = (x - self.mean) * self.img_range
811
+
812
+ if self.upsampler == 'pixelshuffle':
813
+ # for classical SR
814
+ x = self.conv_first(x)
815
+ x = self.conv_after_body(self.forward_features(x)) + x
816
+ x = self.conv_before_upsample(x)
817
+ x = self.conv_last(self.upsample(x))
818
+ elif self.upsampler == 'pixelshuffledirect':
819
+ # for lightweight SR
820
+ x = self.conv_first(x)
821
+ x = self.conv_after_body(self.forward_features(x)) + x
822
+ x = self.upsample(x)
823
+ elif self.upsampler == 'nearest+conv':
824
+ # for real-world SR
825
+ x = self.conv_first(x)
826
+ x = self.conv_after_body(self.forward_features(x)) + x
827
+ x = self.conv_before_upsample(x)
828
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
829
+ if self.upscale == 4:
830
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
831
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
832
+ else:
833
+ # for image denoising and JPEG compression artifact reduction
834
+ x_first = self.conv_first(x)
835
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
836
+ x = x + self.conv_last(res)
837
+
838
+ x = x / self.img_range + self.mean
839
+
840
+ return x[:, :, :H*self.upscale, :W*self.upscale]
841
+
842
+ def flops(self):
843
+ flops = 0
844
+ H, W = self.patches_resolution
845
+ flops += H * W * 3 * self.embed_dim * 9
846
+ flops += self.patch_embed.flops()
847
+ for i, layer in enumerate(self.layers):
848
+ flops += layer.flops()
849
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
850
+ flops += self.upsample.flops()
851
+ return flops
852
+
853
+
854
+ if __name__ == '__main__':
855
+ upscale = 4
856
+ window_size = 8
857
+ height = (1024 // upscale // window_size + 1) * window_size
858
+ width = (720 // upscale // window_size + 1) * window_size
859
+ model = SwinIR(upscale=2, img_size=(height, width),
860
+ window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
861
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
862
+ print(model)
863
+ print(height, width, model.flops() / 1e9)
864
+
865
+ x = torch.randn((1, 3, height, width))
866
+ x = model(x)
867
+ print(x.shape)
modules/txt2img.py ADDED
@@ -0,0 +1,55 @@
1
+ import modules.scripts
2
+ from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
3
+ from modules.shared import opts, cmd_opts
4
+ import modules.shared as shared
5
+ import modules.processing as processing
6
+ from modules.ui import plaintext_to_html
7
+
8
+
9
+ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, *args):
10
+ p = StableDiffusionProcessingTxt2Img(
11
+ sd_model=shared.sd_model,
12
+ outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
13
+ outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
14
+ prompt=prompt,
15
+ styles=[prompt_style, prompt_style2],
16
+ negative_prompt=negative_prompt,
17
+ seed=seed,
18
+ subseed=subseed,
19
+ subseed_strength=subseed_strength,
20
+ seed_resize_from_h=seed_resize_from_h,
21
+ seed_resize_from_w=seed_resize_from_w,
22
+ seed_enable_extras=seed_enable_extras,
23
+ sampler_index=sampler_index,
24
+ batch_size=batch_size,
25
+ n_iter=n_iter,
26
+ steps=steps,
27
+ cfg_scale=cfg_scale,
28
+ width=width,
29
+ height=height,
30
+ restore_faces=restore_faces,
31
+ tiling=tiling,
32
+ enable_hr=enable_hr,
33
+ scale_latent=scale_latent if enable_hr else None,
34
+ denoising_strength=denoising_strength if enable_hr else None,
35
+ )
36
+
37
+ if cmd_opts.enable_console_prompts:
38
+ print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
39
+
40
+ processed = modules.scripts.scripts_txt2img.run(p, *args)
41
+
42
+ if processed is None:
43
+ processed = process_images(p)
44
+
45
+ shared.total_tqdm.clear()
46
+
47
+ generation_info_js = processed.js()
48
+ if opts.samples_log_stdout:
49
+ print(generation_info_js)
50
+
51
+ if opts.do_not_show_images:
52
+ processed.images = []
53
+
54
+ return processed.images, generation_info_js, plaintext_to_html(processed.info)
55
+
modules/ui.py ADDED
@@ -0,0 +1,1414 @@
1
+ import base64
2
+ import html
3
+ import io
4
+ import json
5
+ import math
6
+ import mimetypes
7
+ import os
8
+ import random
9
+ import sys
10
+ import time
11
+ import traceback
12
+ import platform
13
+ import subprocess as sp
14
+ from functools import reduce
15
+
16
+ import numpy as np
17
+ import torch
18
+ from PIL import Image, PngImagePlugin
19
+ import piexif
20
+
21
+ import gradio as gr
22
+ import gradio.utils
23
+ import gradio.routes
24
+
25
+ from modules import sd_hijack
26
+ from modules.paths import script_path
27
+ from modules.shared import opts, cmd_opts
28
+ import modules.shared as shared
29
+ from modules.sd_samplers import samplers, samplers_for_img2img
30
+ from modules.sd_hijack import model_hijack
31
+ import modules.ldsr_model
32
+ import modules.scripts
33
+ import modules.gfpgan_model
34
+ import modules.codeformer_model
35
+ import modules.styles
36
+ import modules.generation_parameters_copypaste
37
+ from modules import prompt_parser
38
+ from modules.images import save_image
39
+ import modules.textual_inversion.ui
40
+
41
+ # this is a fix for Windows users. Without it, JavaScript files will be served with the text/html content type and the browser will not show any UI
42
+ mimetypes.init()
43
+ mimetypes.add_type('application/javascript', '.js')
44
+
45
+
46
+ if not cmd_opts.share and not cmd_opts.listen:
47
+ # fix gradio phoning home
48
+ gradio.utils.version_check = lambda: None
49
+ gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
50
+
51
+
52
+ def gr_show(visible=True):
53
+ return {"visible": visible, "__type__": "update"}
54
+
55
+
56
+ sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
57
+ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
58
+
59
+ css_hide_progressbar = """
60
+ .wrap .m-12 svg { display:none!important; }
61
+ .wrap .m-12::before { content:"Loading..." }
62
+ .progress-bar { display:none!important; }
63
+ .meta-text { display:none!important; }
64
+ """
65
+
66
+ # Using constants for these since the variation selector isn't visible.
67
+ # Important that they exactly match script.js for the tooltips to work.
68
+ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
69
+ reuse_symbol = '\u267b\ufe0f' # ♻️
70
+ art_symbol = '\U0001f3a8' # 🎨
71
+ paste_symbol = '\u2199\ufe0f' # ↙
72
+ folder_symbol = '\U0001f4c2' # 📂
73
+
74
+ def plaintext_to_html(text):
75
+ text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
76
+ return text
77
+
78
+
79
+ def image_from_url_text(filedata):
80
+ if type(filedata) == list:
81
+ if len(filedata) == 0:
82
+ return None
83
+
84
+ filedata = filedata[0]
85
+
86
+ if filedata.startswith("data:image/png;base64,"):
87
+ filedata = filedata[len("data:image/png;base64,"):]
88
+
89
+ filedata = base64.decodebytes(filedata.encode('utf-8'))
90
+ image = Image.open(io.BytesIO(filedata))
91
+ return image
92
+
93
+
94
+ def send_gradio_gallery_to_image(x):
95
+ if len(x) == 0:
96
+ return None
97
+
98
+ return image_from_url_text(x[0])
99
+
100
+
101
+ def save_files(js_data, images, index):
102
+ import csv
103
+ filenames = []
104
+
105
+ # quick dictionary to class object conversion; it's necessary because apply_filename_pattern requires an object rather than a dict
106
+ class MyObject:
107
+ def __init__(self, d=None):
108
+ if d is not None:
109
+ for key, value in d.items():
110
+ setattr(self, key, value)
111
+
112
+ data = json.loads(js_data)
113
+
114
+ p = MyObject(data)
115
+ path = opts.outdir_save
116
+ save_to_dirs = opts.use_save_to_dirs_for_ui
117
+ extension: str = opts.samples_format
118
+ start_index = 0
119
+
120
+ if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
121
+
122
+ images = [images[index]]
123
+ start_index = index
124
+
125
+ with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
126
+ at_start = file.tell() == 0
127
+ writer = csv.writer(file)
128
+ if at_start:
129
+ writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
130
+
131
+ for image_index, filedata in enumerate(images, start_index):
132
+ if filedata.startswith("data:image/png;base64,"):
133
+ filedata = filedata[len("data:image/png;base64,"):]
134
+
135
+ image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
136
+
137
+ is_grid = image_index < p.index_of_first_image
138
+ i = 0 if is_grid else (image_index - p.index_of_first_image)
139
+
140
+ fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
141
+
142
+ filename = os.path.relpath(fullfn, path)
143
+ filenames.append(filename)
144
+
145
+ writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
146
+
147
+ return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
148
+
149
+
150
+ def wrap_gradio_call(func, extra_outputs=None):
151
+ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
152
+ run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
153
+ if run_memmon:
154
+ shared.mem_mon.monitor()
155
+ t = time.perf_counter()
156
+
157
+ try:
158
+ res = list(func(*args, **kwargs))
159
+ except Exception as e:
160
+ print("Error completing request", file=sys.stderr)
161
+ print("Arguments:", args, kwargs, file=sys.stderr)
162
+ print(traceback.format_exc(), file=sys.stderr)
163
+
164
+ shared.state.job = ""
165
+ shared.state.job_count = 0
166
+
167
+ if extra_outputs_array is None:
168
+ extra_outputs_array = [None, '']
169
+
170
+ res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
171
+
172
+ elapsed = time.perf_counter() - t
173
+ elapsed_m = int(elapsed // 60)
174
+ elapsed_s = elapsed % 60
175
+ elapsed_text = f"{elapsed_s:.2f}s"
176
+ if elapsed_m > 0:
177
+ elapsed_text = f"{elapsed_m}m "+elapsed_text
178
+
179
+ if run_memmon:
180
+ mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
181
+ active_peak = mem_stats['active_peak']
182
+ reserved_peak = mem_stats['reserved_peak']
183
+ sys_peak = mem_stats['system_peak']
184
+ sys_total = mem_stats['total']
185
+ sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
186
+
187
+ vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
188
+ else:
189
+ vram_html = ''
190
+
191
+ # last item is always HTML
192
+ res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
193
+
194
+ shared.state.interrupted = False
195
+ shared.state.job_count = 0
196
+
197
+ return tuple(res)
198
+
199
+ return f
200
+
201
+
202
+ def check_progress_call(id_part):
203
+ if shared.state.job_count == 0:
204
+ return "", gr_show(False), gr_show(False), gr_show(False)
205
+
206
+ progress = 0
207
+
208
+ if shared.state.job_count > 0:
209
+ progress += shared.state.job_no / shared.state.job_count
210
+ if shared.state.sampling_steps > 0:
211
+ progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
212
+
213
+ progress = min(progress, 1)
214
+
215
+ progressbar = ""
216
+ if opts.show_progressbar:
217
+ progressbar = f"""<div class='progressDiv'><div class='progress' style="width:{progress * 100}%">{str(int(progress*100))+"%" if progress > 0.01 else ""}</div></div>"""
218
+
219
+ image = gr_show(False)
220
+ preview_visibility = gr_show(False)
221
+
222
+ if opts.show_progress_every_n_steps > 0:
223
+ if shared.parallel_processing_allowed:
224
+
225
+ if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
226
+ shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
227
+ shared.state.current_image_sampling_step = shared.state.sampling_step
228
+
229
+ image = shared.state.current_image
230
+
231
+ if image is None:
232
+ image = gr.update(value=None)
233
+ else:
234
+ preview_visibility = gr_show(True)
235
+
236
+ if shared.state.textinfo is not None:
237
+ textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
238
+ else:
239
+ textinfo_result = gr_show(False)
240
+
241
+ return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
242
+
243
+
244
+ def check_progress_call_initial(id_part):
245
+ shared.state.job_count = -1
246
+ shared.state.current_latent = None
247
+ shared.state.current_image = None
248
+ shared.state.textinfo = None
249
+
250
+ return check_progress_call(id_part)
251
+
252
+
253
+ def roll_artist(prompt):
254
+ allowed_cats = set([x for x in shared.artist_db.categories() if len(opts.random_artist_categories)==0 or x in opts.random_artist_categories])
255
+ artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
256
+
257
+ return prompt + ", " + artist.name if prompt != '' else artist.name
258
+
259
+
260
+ def visit(x, func, path=""):
261
+ if hasattr(x, 'children'):
262
+ for c in x.children:
263
+ visit(c, func, path)
264
+ elif x.label is not None:
265
+ func(path + "/" + str(x.label), x)
266
+
267
+
268
+ def add_style(name: str, prompt: str, negative_prompt: str):
269
+ if name is None:
270
+ return [gr_show(), gr_show()]
271
+
272
+ style = modules.styles.PromptStyle(name, prompt, negative_prompt)
273
+ shared.prompt_styles.styles[style.name] = style
274
+ # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
275
+ # reserialize all styles every time we save them
276
+ shared.prompt_styles.save_styles(shared.styles_filename)
277
+
278
+ return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
279
+
280
+
281
+ def apply_styles(prompt, prompt_neg, style1_name, style2_name):
282
+ prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
283
+ prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
284
+
285
+ return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")]
286
+
287
+
288
+ def interrogate(image):
289
+ prompt = shared.interrogator.interrogate(image)
290
+
291
+ return gr_show(True) if prompt is None else prompt
292
+
293
+
294
+ def create_seed_inputs():
295
+ with gr.Row():
296
+ with gr.Box():
297
+ with gr.Row(elem_id='seed_row'):
298
+ seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1)
299
+ seed.style(container=False)
300
+ random_seed = gr.Button(random_symbol, elem_id='random_seed')
301
+ reuse_seed = gr.Button(reuse_symbol, elem_id='reuse_seed')
302
+
303
+ with gr.Box(elem_id='subseed_show_box'):
304
+ seed_checkbox = gr.Checkbox(label='Extra', elem_id='subseed_show', value=False)
305
+
306
+ # Components to show/hide based on the 'Extra' checkbox
307
+ seed_extras = []
308
+
309
+ with gr.Row(visible=False) as seed_extra_row_1:
310
+ seed_extras.append(seed_extra_row_1)
311
+ with gr.Box():
312
+ with gr.Row(elem_id='subseed_row'):
313
+ subseed = gr.Number(label='Variation seed', value=-1)
314
+ subseed.style(container=False)
315
+ random_subseed = gr.Button(random_symbol, elem_id='random_subseed')
316
+ reuse_subseed = gr.Button(reuse_symbol, elem_id='reuse_subseed')
317
+ subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01)
318
+
319
+ with gr.Row(visible=False) as seed_extra_row_2:
320
+ seed_extras.append(seed_extra_row_2)
321
+ seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0)
322
+ seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0)
323
+
324
+ random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
325
+ random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
326
+
327
+ def change_visibility(show):
328
+ return {comp: gr_show(show) for comp in seed_extras}
329
+
330
+ seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
331
+
332
+ return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
333
+
334
+
335
+ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
336
+ """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
337
+ (sub)seed value from generation info to the seed field. If copying subseed and subseed strength
338
+ was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
339
+ def copy_seed(gen_info_string: str, index):
340
+ res = -1
341
+
342
+ try:
343
+ gen_info = json.loads(gen_info_string)
344
+ index -= gen_info.get('index_of_first_image', 0)
345
+
346
+ if is_subseed and gen_info.get('subseed_strength', 0) > 0:
347
+ all_subseeds = gen_info.get('all_subseeds', [-1])
348
+ res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
349
+ else:
350
+ all_seeds = gen_info.get('all_seeds', [-1])
351
+ res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
352
+
353
+ except json.decoder.JSONDecodeError as e:
354
+ if gen_info_string != '':
355
+ print("Error parsing JSON generation info:", file=sys.stderr)
356
+ print(gen_info_string, file=sys.stderr)
357
+
358
+ return [res, gr_show(False)]
359
+
360
+ reuse_seed.click(
361
+ fn=copy_seed,
362
+ _js="(x, y) => [x, selected_gallery_index()]",
363
+ show_progress=False,
364
+ inputs=[generation_info, dummy_component],
365
+ outputs=[seed, dummy_component]
366
+ )
367
+
368
+
369
+ def update_token_counter(text, steps):
370
+ try:
371
+ _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
372
+ prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
373
+
374
+ except Exception:
375
+ # a parsing error can happen here during typing, and we don't want to bother the user with
376
+ # messages related to it in the console
377
+ prompt_schedules = [[[steps, text]]]
378
+
379
+ flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
380
+ prompts = [prompt_text for step, prompt_text in flat_prompts]
381
+ tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
382
+ style_class = ' class="red"' if (token_count > max_length) else ""
383
+ return f"<span {style_class}>{token_count}/{max_length}</span>"
384
+
385
+
386
+ def create_toprow(is_img2img):
387
+ id_part = "img2img" if is_img2img else "txt2img"
388
+
389
+ with gr.Row(elem_id="toprow"):
390
+ with gr.Column(scale=4):
391
+ with gr.Row():
392
+ with gr.Column(scale=80):
393
+ with gr.Row():
394
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
395
+
396
+ with gr.Column(scale=1, elem_id="roll_col"):
397
+ roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
398
+ paste = gr.Button(value=paste_symbol, elem_id="paste")
399
+ token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
400
+ token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
401
+
402
+ with gr.Column(scale=10, elem_id="style_pos_col"):
403
+ prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
404
+
405
+ with gr.Row():
406
+ with gr.Column(scale=8):
407
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
408
+
409
+ with gr.Column(scale=1, elem_id="style_neg_col"):
410
+ prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
411
+
412
+ with gr.Column(scale=1):
413
+ with gr.Row():
414
+ interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
415
+ submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
416
+
417
+ interrupt.click(
418
+ fn=lambda: shared.state.interrupt(),
419
+ inputs=[],
420
+ outputs=[],
421
+ )
422
+
423
+ with gr.Row():
424
+ if is_img2img:
425
+ interrogate = gr.Button('Interrogate', elem_id="interrogate")
426
+ else:
427
+ interrogate = None
428
+ prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
429
+ save_style = gr.Button('Create style', elem_id="style_create")
430
+
431
+ return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button
432
+
433
+
434
+ def setup_progressbar(progressbar, preview, id_part, textinfo=None):
435
+ if textinfo is None:
436
+ textinfo = gr.HTML(visible=False)
437
+
438
+ check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
439
+ check_progress.click(
440
+ fn=lambda: check_progress_call(id_part),
441
+ show_progress=False,
442
+ inputs=[],
443
+ outputs=[progressbar, preview, preview, textinfo],
444
+ )
445
+
446
+ check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
447
+ check_progress_initial.click(
448
+ fn=lambda: check_progress_call_initial(id_part),
449
+ show_progress=False,
450
+ inputs=[],
451
+ outputs=[progressbar, preview, preview, textinfo],
452
+ )
453
+
454
+
455
+ def create_ui(wrap_gradio_gpu_call):
456
+ import modules.img2img
457
+ import modules.txt2img
458
+
459
+ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
460
+ txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
461
+ dummy_component = gr.Label(visible=False)
462
+
463
+ with gr.Row(elem_id='txt2img_progress_row'):
464
+ with gr.Column(scale=1):
465
+ pass
466
+
467
+ with gr.Column(scale=1):
468
+ progressbar = gr.HTML(elem_id="txt2img_progressbar")
469
+ txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
470
+ setup_progressbar(progressbar, txt2img_preview, 'txt2img')
471
+
472
+ with gr.Row().style(equal_height=False):
473
+ with gr.Column(variant='panel'):
474
+ steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
475
+ sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
476
+
477
+ with gr.Group():
478
+ width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
479
+ height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
480
+
481
+ with gr.Row():
482
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
483
+ tiling = gr.Checkbox(label='Tiling', value=False)
484
+ enable_hr = gr.Checkbox(label='Highres. fix', value=False)
485
+
486
+ with gr.Row(visible=False) as hr_options:
487
+ scale_latent = gr.Checkbox(label='Scale latent', value=False)
488
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
489
+
490
+ with gr.Row():
491
+ batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
492
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
493
+
494
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
495
+
496
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
497
+
498
+ with gr.Group():
499
+ custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
500
+
501
+ with gr.Column(variant='panel'):
502
+
503
+ with gr.Group():
504
+ txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
505
+ txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4)
506
+
507
+ with gr.Group():
508
+ with gr.Row():
509
+ save = gr.Button('Save')
510
+ send_to_img2img = gr.Button('Send to img2img')
511
+ send_to_inpaint = gr.Button('Send to inpaint')
512
+ send_to_extras = gr.Button('Send to extras')
513
+ button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
514
+ open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
515
+
516
+ with gr.Group():
517
+ html_info = gr.HTML()
518
+ generation_info = gr.Textbox(visible=False)
519
+
520
+ connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
521
+ connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
522
+
523
+ txt2img_args = dict(
524
+ fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
525
+ _js="submit",
526
+ inputs=[
527
+ txt2img_prompt,
528
+ txt2img_negative_prompt,
529
+ txt2img_prompt_style,
530
+ txt2img_prompt_style2,
531
+ steps,
532
+ sampler_index,
533
+ restore_faces,
534
+ tiling,
535
+ batch_count,
536
+ batch_size,
537
+ cfg_scale,
538
+ seed,
539
+ subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
540
+ height,
541
+ width,
542
+ enable_hr,
543
+ scale_latent,
544
+ denoising_strength,
545
+ ] + custom_inputs,
546
+ outputs=[
547
+ txt2img_gallery,
548
+ generation_info,
549
+ html_info
550
+ ],
551
+ show_progress=False,
552
+ )
553
+
554
+ txt2img_prompt.submit(**txt2img_args)
555
+ submit.click(**txt2img_args)
556
+
557
+ enable_hr.change(
558
+ fn=lambda x: gr_show(x),
559
+ inputs=[enable_hr],
560
+ outputs=[hr_options],
561
+ )
562
+
563
+ save.click(
564
+ fn=wrap_gradio_call(save_files),
565
+ _js="(x, y, z) => [x, y, selected_gallery_index()]",
566
+ inputs=[
567
+ generation_info,
568
+ txt2img_gallery,
569
+ html_info,
570
+ ],
571
+ outputs=[
572
+ html_info,
573
+ html_info,
574
+ html_info,
575
+ ]
576
+ )
577
+
578
+ roll.click(
579
+ fn=roll_artist,
580
+ _js="update_txt2img_tokens",
581
+ inputs=[
582
+ txt2img_prompt,
583
+ ],
584
+ outputs=[
585
+ txt2img_prompt,
586
+ ]
587
+ )
588
+
589
+ txt2img_paste_fields = [
590
+ (txt2img_prompt, "Prompt"),
591
+ (txt2img_negative_prompt, "Negative prompt"),
592
+ (steps, "Steps"),
593
+ (sampler_index, "Sampler"),
594
+ (restore_faces, "Face restoration"),
595
+ (cfg_scale, "CFG scale"),
596
+ (seed, "Seed"),
597
+ (width, "Size-1"),
598
+ (height, "Size-2"),
599
+ (batch_size, "Batch size"),
600
+ (subseed, "Variation seed"),
601
+ (subseed_strength, "Variation seed strength"),
602
+ (seed_resize_from_w, "Seed resize from-1"),
603
+ (seed_resize_from_h, "Seed resize from-2"),
604
+ (denoising_strength, "Denoising strength"),
605
+ (enable_hr, lambda d: "Denoising strength" in d),
606
+ (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
607
+ ]
608
+ modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
609
+ token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
610
+
611
+ with gr.Blocks(analytics_enabled=False) as img2img_interface:
612
+ img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
613
+
614
+ with gr.Row(elem_id='img2img_progress_row'):
615
+ with gr.Column(scale=1):
616
+ pass
617
+
618
+ with gr.Column(scale=1):
619
+ progressbar = gr.HTML(elem_id="img2img_progressbar")
620
+ img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
621
+ setup_progressbar(progressbar, img2img_preview, 'img2img')
622
+
623
+ with gr.Row().style(equal_height=False):
624
+ with gr.Column(variant='panel'):
625
+
626
+ with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
627
+ with gr.TabItem('img2img', id='img2img'):
628
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool)
629
+
630
+ with gr.TabItem('Inpaint', id='inpaint'):
631
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
632
+
633
+ init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
634
+ init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
635
+
636
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
637
+
638
+ with gr.Row():
639
+ mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
640
+ inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")
641
+
642
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index")
643
+
644
+ with gr.Row():
645
+ inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
646
+ inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32)
647
+
648
+ with gr.TabItem('Batch img2img', id='batch'):
649
+ hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
650
+ gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
651
+ img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
652
+ img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
653
+
654
+ with gr.Row():
655
+ resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
656
+
657
+ steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
658
+ sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
659
+
660
+ with gr.Group():
661
+ width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
662
+ height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
663
+
664
+ with gr.Row():
665
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
666
+ tiling = gr.Checkbox(label='Tiling', value=False)
667
+
668
+ with gr.Row():
669
+ batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
670
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
671
+
672
+ with gr.Group():
673
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
674
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75)
675
+
676
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
677
+
678
+ with gr.Group():
679
+ custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
680
+
681
+ with gr.Column(variant='panel'):
682
+
683
+ with gr.Group():
684
+ img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
685
+ img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4)
686
+
687
+ with gr.Group():
688
+ with gr.Row():
689
+ save = gr.Button('Save')
690
+ img2img_send_to_img2img = gr.Button('Send to img2img')
691
+ img2img_send_to_inpaint = gr.Button('Send to inpaint')
692
+ img2img_send_to_extras = gr.Button('Send to extras')
693
+ button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
694
+ open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
695
+
696
+ with gr.Group():
697
+ html_info = gr.HTML()
698
+ generation_info = gr.Textbox(visible=False)
699
+
700
+ connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
701
+ connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
702
+
703
+ mask_mode.change(
704
+ lambda mode, img: {
705
+ init_img_with_mask: gr_show(mode == 0),
706
+ init_img_inpaint: gr_show(mode == 1),
707
+ init_mask_inpaint: gr_show(mode == 1),
708
+ },
709
+ inputs=[mask_mode, init_img_with_mask],
710
+ outputs=[
711
+ init_img_with_mask,
712
+ init_img_inpaint,
713
+ init_mask_inpaint,
714
+ ],
715
+ )
716
+
717
+ img2img_args = dict(
718
+ fn=wrap_gradio_gpu_call(modules.img2img.img2img),
719
+ _js="submit_img2img",
720
+ inputs=[
721
+ dummy_component,
722
+ img2img_prompt,
723
+ img2img_negative_prompt,
724
+ img2img_prompt_style,
725
+ img2img_prompt_style2,
726
+ init_img,
727
+ init_img_with_mask,
728
+ init_img_inpaint,
729
+ init_mask_inpaint,
730
+ mask_mode,
731
+ steps,
732
+ sampler_index,
733
+ mask_blur,
734
+ inpainting_fill,
735
+ restore_faces,
736
+ tiling,
737
+ batch_count,
738
+ batch_size,
739
+ cfg_scale,
740
+ denoising_strength,
741
+ seed,
742
+ subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
743
+ height,
744
+ width,
745
+ resize_mode,
746
+ inpaint_full_res,
747
+ inpaint_full_res_padding,
748
+ inpainting_mask_invert,
749
+ img2img_batch_input_dir,
750
+ img2img_batch_output_dir,
751
+ ] + custom_inputs,
752
+ outputs=[
753
+ img2img_gallery,
754
+ generation_info,
755
+ html_info
756
+ ],
757
+ show_progress=False,
758
+ )
759
+
760
+ img2img_prompt.submit(**img2img_args)
761
+ submit.click(**img2img_args)
762
+
763
+ img2img_interrogate.click(
764
+ fn=interrogate,
765
+ inputs=[init_img],
766
+ outputs=[img2img_prompt],
767
+ )
768
+
769
+ save.click(
770
+ fn=wrap_gradio_call(save_files),
771
+ _js="(x, y, z) => [x, y, selected_gallery_index()]",
772
+ inputs=[
773
+ generation_info,
774
+ img2img_gallery,
775
+ html_info
776
+ ],
777
+ outputs=[
778
+ html_info,
779
+ html_info,
780
+ html_info,
781
+ ]
782
+ )
783
+
784
+ roll.click(
785
+ fn=roll_artist,
786
+ _js="update_img2img_tokens",
787
+ inputs=[
788
+ img2img_prompt,
789
+ ],
790
+ outputs=[
791
+ img2img_prompt,
792
+ ]
793
+ )
794
+
795
+ prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
796
+ style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
797
+ style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
798
+
799
+ for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
800
+ button.click(
801
+ fn=add_style,
802
+ _js="ask_for_style_name",
803
+ # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
804
+ # the same number of parameters, but we only know the style-name after the JavaScript prompt
805
+ inputs=[dummy_component, prompt, negative_prompt],
806
+ outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
807
+ )
808
+
809
+ for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
810
+ button.click(
811
+ fn=apply_styles,
812
+ _js=js_func,
813
+ inputs=[prompt, negative_prompt, style1, style2],
814
+ outputs=[prompt, negative_prompt, style1, style2],
815
+ )
816
+
817
+ img2img_paste_fields = [
818
+ (img2img_prompt, "Prompt"),
819
+ (img2img_negative_prompt, "Negative prompt"),
820
+ (steps, "Steps"),
821
+ (sampler_index, "Sampler"),
822
+ (restore_faces, "Face restoration"),
823
+ (cfg_scale, "CFG scale"),
824
+ (seed, "Seed"),
825
+ (width, "Size-1"),
826
+ (height, "Size-2"),
827
+ (batch_size, "Batch size"),
828
+ (subseed, "Variation seed"),
829
+ (subseed_strength, "Variation seed strength"),
830
+ (seed_resize_from_w, "Seed resize from-1"),
831
+ (seed_resize_from_h, "Seed resize from-2"),
832
+ (denoising_strength, "Denoising strength"),
833
+ ]
834
+ modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)
835
+ token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
836
+
837
+ with gr.Blocks(analytics_enabled=False) as extras_interface:
838
+ with gr.Row().style(equal_height=False):
839
+ with gr.Column(variant='panel'):
840
+ with gr.Tabs(elem_id="mode_extras"):
841
+ with gr.TabItem('Single Image'):
842
+ extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil")
843
+
844
+ with gr.TabItem('Batch Process'):
845
+ image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
846
+
847
+ upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
848
+
849
+ with gr.Group():
850
+ extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
851
+
852
+ with gr.Group():
853
+ extras_upscaler_2 = gr.Radio(label='Upscaler 2', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
854
+ extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
855
+
856
+ with gr.Group():
857
+ gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan)
858
+
859
+ with gr.Group():
860
+ codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer)
861
+ codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer)
862
+
863
+ submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
864
+
865
+ with gr.Column(variant='panel'):
866
+ result_images = gr.Gallery(label="Result", show_label=False)
867
+ html_info_x = gr.HTML()
868
+ html_info = gr.HTML()
869
+ extras_send_to_img2img = gr.Button('Send to img2img')
870
+ extras_send_to_inpaint = gr.Button('Send to inpaint')
871
+ button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
872
+ open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
873
+
874
+ submit.click(
875
+ fn=wrap_gradio_gpu_call(modules.extras.run_extras),
876
+ _js="get_extras_tab_index",
877
+ inputs=[
878
+ dummy_component,
879
+ extras_image,
880
+ image_batch,
881
+ gfpgan_visibility,
882
+ codeformer_visibility,
883
+ codeformer_weight,
884
+ upscaling_resize,
885
+ extras_upscaler_1,
886
+ extras_upscaler_2,
887
+ extras_upscaler_2_visibility,
888
+ ],
889
+ outputs=[
890
+ result_images,
891
+ html_info_x,
892
+ html_info,
893
+ ]
894
+ )
895
+
896
+ extras_send_to_img2img.click(
897
+ fn=lambda x: image_from_url_text(x),
898
+ _js="extract_image_from_gallery_img2img",
899
+ inputs=[result_images],
900
+ outputs=[init_img],
901
+ )
902
+
903
+ extras_send_to_inpaint.click(
904
+ fn=lambda x: image_from_url_text(x),
905
+ _js="extract_image_from_gallery_img2img",
906
+ inputs=[result_images],
907
+ outputs=[init_img_with_mask],
908
+ )
909
+
910
+ with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
911
+ with gr.Row().style(equal_height=False):
912
+ with gr.Column(variant='panel'):
913
+ image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
914
+
915
+ with gr.Column(variant='panel'):
916
+ html = gr.HTML()
917
+ generation_info = gr.Textbox(visible=False)
918
+ html2 = gr.HTML()
919
+
920
+ with gr.Row():
921
+ pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
922
+ pnginfo_send_to_img2img = gr.Button('Send to img2img')
923
+
924
+ image.change(
925
+ fn=wrap_gradio_call(modules.extras.run_pnginfo),
926
+ inputs=[image],
927
+ outputs=[html, generation_info, html2],
928
+ )
929
+
930
+ with gr.Blocks() as modelmerger_interface:
931
+ with gr.Row().style(equal_height=False):
932
+ with gr.Column(variant='panel'):
933
+ gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
934
+
935
+ with gr.Row():
936
+ primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
937
+ secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
938
+ custom_name = gr.Textbox(label="Custom Name (Optional)")
939
+ interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
940
+ interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
941
+ save_as_half = gr.Checkbox(value=False, label="Save as float16")
942
+ modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
943
+
944
+ with gr.Column(variant='panel'):
945
+ submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
946
+
947
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
948
+
949
+ with gr.Blocks() as textual_inversion_interface:
950
+ with gr.Row().style(equal_height=False):
951
+ with gr.Column():
952
+ with gr.Group():
953
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
954
+
955
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
956
+
957
+ new_embedding_name = gr.Textbox(label="Name")
958
+ initialization_text = gr.Textbox(label="Initialization text", value="*")
959
+ nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
960
+
961
+ with gr.Row():
962
+ with gr.Column(scale=3):
963
+ gr.HTML(value="")
964
+
965
+ with gr.Column():
966
+ create_embedding = gr.Button(value="Create", variant='primary')
967
+
968
+ with gr.Group():
969
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
970
+
971
+ process_src = gr.Textbox(label='Source directory')
972
+ process_dst = gr.Textbox(label='Destination directory')
973
+
974
+ with gr.Row():
975
+ process_flip = gr.Checkbox(label='Flip')
976
+ process_split = gr.Checkbox(label='Split into two')
977
+ process_caption = gr.Checkbox(label='Add caption')
978
+
979
+ with gr.Row():
980
+ with gr.Column(scale=3):
981
+ gr.HTML(value="")
982
+
983
+ with gr.Column():
984
+ run_preprocess = gr.Button(value="Preprocess", variant='primary')
985
+
986
+ with gr.Group():
987
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
988
+ train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
989
+ learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
990
+ dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
991
+ log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
992
+ template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
993
+ steps = gr.Number(label='Max steps', value=100000, precision=0)
994
+ create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
995
+ save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
996
+
997
+ with gr.Row():
998
+ with gr.Column(scale=2):
999
+ gr.HTML(value="")
1000
+
1001
+ with gr.Column():
1002
+ with gr.Row():
1003
+ interrupt_training = gr.Button(value="Interrupt")
1004
+ train_embedding = gr.Button(value="Train", variant='primary')
1005
+
1006
+ with gr.Column():
1007
+ progressbar = gr.HTML(elem_id="ti_progressbar")
1008
+ ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
1009
+
1010
+ ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
1011
+ ti_preview = gr.Image(elem_id='ti_preview', visible=False)
1012
+ ti_progress = gr.HTML(elem_id="ti_progress", value="")
1013
+ ti_outcome = gr.HTML(elem_id="ti_error", value="")
1014
+ setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
1015
+
1016
+ create_embedding.click(
1017
+ fn=modules.textual_inversion.ui.create_embedding,
1018
+ inputs=[
1019
+ new_embedding_name,
1020
+ initialization_text,
1021
+ nvpt,
1022
+ ],
1023
+ outputs=[
1024
+ train_embedding_name,
1025
+ ti_output,
1026
+ ti_outcome,
1027
+ ]
1028
+ )
1029
+
1030
+ run_preprocess.click(
1031
+ fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
1032
+ _js="start_training_textual_inversion",
1033
+ inputs=[
1034
+ process_src,
1035
+ process_dst,
1036
+ process_flip,
1037
+ process_split,
1038
+ process_caption,
1039
+ ],
1040
+ outputs=[
1041
+ ti_output,
1042
+ ti_outcome,
1043
+ ],
1044
+ )
1045
+
1046
+ train_embedding.click(
1047
+ fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
1048
+ _js="start_training_textual_inversion",
1049
+ inputs=[
1050
+ train_embedding_name,
1051
+ learn_rate,
1052
+ dataset_directory,
1053
+ log_directory,
1054
+ steps,
1055
+ create_image_every,
1056
+ save_embedding_every,
1057
+ template_file,
1058
+ ],
1059
+ outputs=[
1060
+ ti_output,
1061
+ ti_outcome,
1062
+ ]
1063
+ )
1064
+
1065
+ interrupt_training.click(
1066
+ fn=lambda: shared.state.interrupt(),
1067
+ inputs=[],
1068
+ outputs=[],
1069
+ )
1070
+
1071
+ def create_setting_component(key):
1072
+ def fun():
1073
+ return opts.data[key] if key in opts.data else opts.data_labels[key].default
1074
+
1075
+ info = opts.data_labels[key]
1076
+ t = type(info.default)
1077
+
1078
+ args = info.component_args() if callable(info.component_args) else info.component_args
1079
+
1080
+ if info.component is not None:
1081
+ comp = info.component
1082
+ elif t == str:
1083
+ comp = gr.Textbox
1084
+ elif t == int:
1085
+ comp = gr.Number
1086
+ elif t == bool:
1087
+ comp = gr.Checkbox
1088
+ else:
1089
+ raise Exception(f'bad options item type: {str(t)} for key {key}')
1090
+
1091
+ return comp(label=info.label, value=fun, **(args or {}))
1092
+
1093
+ components = []
1094
+ component_dict = {}
1095
+
1096
+ def open_folder(f):
1097
+ if not shared.cmd_opts.hide_ui_dir_config:
1098
+ path = os.path.normpath(f)
1099
+ if platform.system() == "Windows":
1100
+ os.startfile(path)
1101
+ elif platform.system() == "Darwin":
1102
+ sp.Popen(["open", path])
1103
+ else:
1104
+ sp.Popen(["xdg-open", path])
1105
+
1106
+ def run_settings(*args):
1107
+ changed = 0
1108
+
1109
+ for key, value, comp in zip(opts.data_labels.keys(), args, components):
1110
+ if not opts.same_type(value, opts.data_labels[key].default):
1111
+ return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
1112
+
1113
+ for key, value, comp in zip(opts.data_labels.keys(), args, components):
1114
+ comp_args = opts.data_labels[key].component_args
1115
+ if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
1116
+ continue
1117
+
1118
+ oldval = opts.data.get(key, None)
1119
+ opts.data[key] = value
1120
+
1121
+ if oldval != value:
1122
+ if opts.data_labels[key].onchange is not None:
1123
+ opts.data_labels[key].onchange()
1124
+
1125
+ changed += 1
1126
+
1127
+ opts.save(shared.config_filename)
1128
+
1129
+ return f'{changed} settings changed.', opts.dumpjson()
1130
+
1131
+ with gr.Blocks(analytics_enabled=False) as settings_interface:
1132
+ settings_submit = gr.Button(value="Apply settings", variant='primary')
1133
+ result = gr.HTML()
1134
+
1135
+ settings_cols = 3
1136
+ items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
1137
+
1138
+ cols_displayed = 0
1139
+ items_displayed = 0
1140
+ previous_section = None
1141
+ column = None
1142
+ with gr.Row(elem_id="settings").style(equal_height=False):
1143
+ for i, (k, item) in enumerate(opts.data_labels.items()):
1144
+
1145
+ if previous_section != item.section:
1146
+ if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None):
1147
+ if column is not None:
1148
+ column.__exit__()
1149
+
1150
+ column = gr.Column(variant='panel')
1151
+ column.__enter__()
1152
+
1153
+ items_displayed = 0
1154
+ cols_displayed += 1
1155
+
1156
+ previous_section = item.section
1157
+
1158
+ gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
1159
+
1160
+ component = create_setting_component(k)
1161
+ component_dict[k] = component
1162
+ components.append(component)
1163
+ items_displayed += 1
1164
+
1165
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
1166
+ request_notifications.click(
1167
+ fn=lambda: None,
1168
+ inputs=[],
1169
+ outputs=[],
1170
+ _js='function(){}'
1171
+ )
1172
+
1173
+ with gr.Row():
1174
+ reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
1175
+ restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
1176
+
1177
+
1178
+ def reload_scripts():
1179
+ modules.scripts.reload_script_body_only()
1180
+
1181
+ reload_script_bodies.click(
1182
+ fn=reload_scripts,
1183
+ inputs=[],
1184
+ outputs=[],
1185
+ _js='function(){}'
1186
+ )
1187
+
1188
+ def request_restart():
1189
+ shared.state.interrupt()
1190
+ settings_interface.gradio_ref.do_restart = True
1191
+
1192
+ restart_gradio.click(
1193
+ fn=request_restart,
1194
+ inputs=[],
1195
+ outputs=[],
1196
+ _js='function(){restart_reload()}'
1197
+ )
1198
+
1199
+ if column is not None:
1200
+ column.__exit__()
1201
+
1202
+ interfaces = [
1203
+ (txt2img_interface, "txt2img", "txt2img"),
1204
+ (img2img_interface, "img2img", "img2img"),
1205
+ (extras_interface, "Extras", "extras"),
1206
+ (pnginfo_interface, "PNG Info", "pnginfo"),
1207
+ (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
1208
+ (textual_inversion_interface, "Textual inversion", "ti"),
1209
+ (settings_interface, "Settings", "settings"),
1210
+ ]
1211
+
1212
+ with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
1213
+ css = file.read()
1214
+
1215
+ if os.path.exists(os.path.join(script_path, "user.css")):
1216
+ with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
1217
+ usercss = file.read()
1218
+ css += usercss
1219
+
1220
+ if not cmd_opts.no_progressbar_hiding:
1221
+ css += css_hide_progressbar
1222
+
1223
+ with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
1224
+
1225
+ settings_interface.gradio_ref = demo
1226
+
1227
+ with gr.Tabs() as tabs:
1228
+ for interface, label, ifid in interfaces:
1229
+ with gr.TabItem(label, id=ifid):
1230
+ interface.render()
1231
+
1232
+ if os.path.exists(os.path.join(script_path, "notification.mp3")):
1233
+ audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
1234
+
1235
+ text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
1236
+ settings_submit.click(
1237
+ fn=run_settings,
1238
+ inputs=components,
1239
+ outputs=[result, text_settings],
1240
+ )
1241
+
1242
+ def modelmerger(*args):
1243
+ try:
1244
+ results = modules.extras.run_modelmerger(*args)
1245
+ except Exception as e:
1246
+ print("Error loading/saving model file:", file=sys.stderr)
1247
+ print(traceback.format_exc(), file=sys.stderr)
1248
+ modules.sd_models.list_models() # to remove the potentially missing models from the list
1249
+ return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
1250
+ return results
1251
+
1252
+ modelmerger_merge.click(
1253
+ fn=modelmerger,
1254
+ inputs=[
1255
+ primary_model_name,
1256
+ secondary_model_name,
1257
+ interp_method,
1258
+ interp_amount,
1259
+ save_as_half,
1260
+ custom_name,
1261
+ ],
1262
+ outputs=[
1263
+ submit_result,
1264
+ primary_model_name,
1265
+ secondary_model_name,
1266
+ component_dict['sd_model_checkpoint'],
1267
+ ]
1268
+ )
1269
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2']
1270
+ txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names]
1271
+ img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names]
1272
+ send_to_img2img.click(
1273
+ fn=lambda img, *args: (image_from_url_text(img),*args),
1274
+ _js="(gallery, ...args) => [extract_image_from_gallery_img2img(gallery), ...args]",
1275
+ inputs=[txt2img_gallery] + txt2img_fields,
1276
+ outputs=[init_img] + img2img_fields,
1277
+ )
1278
+
1279
+ send_to_inpaint.click(
1280
+ fn=lambda x, *args: (image_from_url_text(x), *args),
1281
+ _js="(gallery, ...args) => [extract_image_from_gallery_inpaint(gallery), ...args]",
1282
+ inputs=[txt2img_gallery] + txt2img_fields,
1283
+ outputs=[init_img_with_mask] + img2img_fields,
1284
+ )
1285
+
1286
+ img2img_send_to_img2img.click(
1287
+ fn=lambda x: image_from_url_text(x),
1288
+ _js="extract_image_from_gallery_img2img",
1289
+ inputs=[img2img_gallery],
1290
+ outputs=[init_img],
1291
+ )
1292
+
1293
+ img2img_send_to_inpaint.click(
1294
+ fn=lambda x: image_from_url_text(x),
1295
+ _js="extract_image_from_gallery_inpaint",
1296
+ inputs=[img2img_gallery],
1297
+ outputs=[init_img_with_mask],
1298
+ )
1299
+
1300
+ send_to_extras.click(
1301
+ fn=lambda x: image_from_url_text(x),
1302
+ _js="extract_image_from_gallery_extras",
1303
+ inputs=[txt2img_gallery],
1304
+ outputs=[extras_image],
1305
+ )
1306
+
1307
+ open_txt2img_folder.click(
1308
+ fn=lambda: open_folder(opts.outdir_samples or opts.outdir_txt2img_samples),
1309
+ inputs=[],
1310
+ outputs=[],
1311
+ )
1312
+
1313
+ open_img2img_folder.click(
1314
+ fn=lambda: open_folder(opts.outdir_samples or opts.outdir_img2img_samples),
1315
+ inputs=[],
1316
+ outputs=[],
1317
+ )
1318
+
1319
+ open_extras_folder.click(
1320
+ fn=lambda: open_folder(opts.outdir_samples or opts.outdir_extras_samples),
1321
+ inputs=[],
1322
+ outputs=[],
1323
+ )
1324
+
1325
+ img2img_send_to_extras.click(
1326
+ fn=lambda x: image_from_url_text(x),
1327
+ _js="extract_image_from_gallery_extras",
1328
+ inputs=[img2img_gallery],
1329
+ outputs=[extras_image],
1330
+ )
1331
+
1332
+ modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields, generation_info, 'switch_to_txt2img')
1333
+ modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields, generation_info, 'switch_to_img2img_img2img')
1334
+
1335
+ ui_config_file = cmd_opts.ui_config_file
1336
+ ui_settings = {}
1337
+ settings_count = len(ui_settings)
1338
+ error_loading = False
1339
+
1340
+ try:
1341
+ if os.path.exists(ui_config_file):
1342
+ with open(ui_config_file, "r", encoding="utf8") as file:
1343
+ ui_settings = json.load(file)
1344
+ except Exception:
1345
+ error_loading = True
1346
+ print("Error loading settings:", file=sys.stderr)
1347
+ print(traceback.format_exc(), file=sys.stderr)
1348
+
1349
+ def loadsave(path, x):
1350
+ def apply_field(obj, field, condition=None):
1351
+ key = path + "/" + field
1352
+
1353
+ if getattr(obj, 'custom_script_source', None) is not None:
1354
+ key = 'customscript/' + obj.custom_script_source + '/' + key
1355
+
1356
+ if getattr(obj, 'do_not_save_to_config', False):
1357
+ return
1358
+
1359
+ saved_value = ui_settings.get(key, None)
1360
+ if saved_value is None:
1361
+ ui_settings[key] = getattr(obj, field)
1362
+ elif condition is None or condition(saved_value):
1363
+ setattr(obj, field, saved_value)
1364
+
1365
+ if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible:
1366
+ apply_field(x, 'visible')
1367
+
1368
+ if type(x) == gr.Slider:
1369
+ apply_field(x, 'value')
1370
+ apply_field(x, 'minimum')
1371
+ apply_field(x, 'maximum')
1372
+ apply_field(x, 'step')
1373
+
1374
+ if type(x) == gr.Radio:
1375
+ apply_field(x, 'value', lambda val: val in x.choices)
1376
+
1377
+ if type(x) == gr.Checkbox:
1378
+ apply_field(x, 'value')
1379
+
1380
+ if type(x) == gr.Textbox:
1381
+ apply_field(x, 'value')
1382
+
1383
+ if type(x) == gr.Number:
1384
+ apply_field(x, 'value')
1385
+
1386
+ visit(txt2img_interface, loadsave, "txt2img")
1387
+ visit(img2img_interface, loadsave, "img2img")
1388
+ visit(extras_interface, loadsave, "extras")
1389
+
1390
+ if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
1391
+ with open(ui_config_file, "w", encoding="utf8") as file:
1392
+ json.dump(ui_settings, file, indent=4)
1393
+
1394
+ return demo
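The ui-config handling above stores each component's attributes under keys of the form path/field on first run and restores them on later runs; the write-back guard skips saving when loading failed, so a malformed ui-config.json is not overwritten. A minimal, self-contained sketch of that save-or-restore rule (all names here are illustrative, not part of the commit):

# Illustrative sketch of the save-or-restore rule implemented by apply_field() above.
ui_settings = {"txt2img/Steps/value": 20}    # pretend this was loaded from ui-config.json

class FakeSlider:
    value = 50                               # the default the UI was built with

def apply_field(settings, path, obj, field, condition=None):
    key = f"{path}/{field}"
    saved = settings.get(key)
    if saved is None:
        settings[key] = getattr(obj, field)  # unknown key: record the current default
    elif condition is None or condition(saved):
        setattr(obj, field, saved)           # known key: restore the saved value

slider = FakeSlider()
apply_field(ui_settings, "txt2img/Steps", slider, "value")
assert slider.value == 20                    # the saved value wins over the built-in default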
1395
+
1396
+
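+ # collect script.js plus every file under javascript/ so it can be injected into Gradio's pages below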
1397
+ with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
1398
+ javascript = f'<script>{jsfile.read()}</script>'
1399
+
1400
+ jsdir = os.path.join(script_path, "javascript")
1401
+ for filename in sorted(os.listdir(jsdir)):
1402
+ with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
1403
+ javascript += f"\n<script>{jsfile.read()}</script>"
1404
+
1405
+
1406
+ if 'gradio_routes_templates_response' not in globals():
1407
+ def template_response(*args, **kwargs):
1408
+ res = gradio_routes_templates_response(*args, **kwargs)
1409
+ res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8"))
1410
+ res.init_headers()
1411
+ return res
1412
+
1413
+ gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
1414
+ gradio.routes.templates.TemplateResponse = template_response
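The last two lines monkey-patch gradio.routes.templates.TemplateResponse so that every page Gradio serves gets the collected JavaScript spliced in just before the closing head tag. A minimal standalone sketch of the same splice-into-head pattern, using made-up names (make_page, EXTRA_HEAD) rather than Gradio internals:

EXTRA_HEAD = "<script>console.log('injected');</script>"

def make_page() -> bytes:
    # stand-in for whatever normally renders the HTML response
    return b"<html><head><title>demo</title></head><body>hello</body></html>"

_original_make_page = make_page          # keep a reference to the unpatched callable

def patched_make_page() -> bytes:
    body = _original_make_page()
    return body.replace(b"</head>", EXTRA_HEAD.encode("utf8") + b"</head>")

make_page = patched_make_page            # swap the original for the wrapper

assert b"console.log" in make_page()

The "'gradio_routes_templates_response' not in globals()" guard above serves the same purpose as keeping _original_make_page: it prevents the patch from wrapping itself if the module is executed twice.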
modules/upscaler.py ADDED
@@ -0,0 +1,121 @@
1
+ import os
2
+ from abc import abstractmethod
3
+
4
+ import PIL
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+
9
+ import modules.shared
10
+ from modules import modelloader, shared
11
+
12
+ from modules.paths import models_path
13
+ LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
14
+
15
+
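+ # base class for all upscalers; concrete implementations override do_upscale() and load_model()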
16
+ class Upscaler:
17
+ name = None
18
+ model_path = None
19
+ model_name = None
20
+ model_url = None
21
+ enable = True
22
+ filter = None
23
+ model = None
24
+ user_path = None
25
+ scalers = []
26
+ tile = True
27
+
28
+ def __init__(self, create_dirs=False):
29
+ self.mod_pad_h = None
30
+ self.tile_size = modules.shared.opts.ESRGAN_tile
31
+ self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
32
+ self.device = modules.shared.device
33
+ self.img = None
34
+ self.output = None
35
+ self.scale = 1
36
+ self.half = not modules.shared.cmd_opts.no_half
37
+ self.pre_pad = 0
38
+ self.mod_scale = None
39
+ if self.name is not None and create_dirs:
40
+ self.model_path = os.path.join(models_path, self.name)
41
+ if not os.path.exists(self.model_path):
42
+ os.makedirs(self.model_path)
43
+
44
+ try:
45
+ import cv2
46
+ self.can_tile = True
47
+ except ImportError:
48
+ self.can_tile = False
49
+
50
+ @abstractmethod
51
+ def do_upscale(self, img: PIL.Image, selected_model: str):
52
+ return img
53
+
54
+ def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
55
+ self.scale = scale
56
+ dest_w = img.width * scale
57
+ dest_h = img.height * scale
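+ # apply the model up to three times until the requested size is reached or exceeded, then resize to the exact target with LANCZOS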
58
+ for i in range(3):
59
+ if img.width >= dest_w and img.height >= dest_h:
60
+ break
61
+ img = self.do_upscale(img, selected_model)
62
+ if img.width != dest_w or img.height != dest_h:
63
+ img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
64
+
65
+ return img
66
+
67
+ @abstractmethod
68
+ def load_model(self, path: str):
69
+ pass
70
+
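+ # note: ext_filter is accepted here but not forwarded to modelloader.load_models in this version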
71
+ def find_models(self, ext_filter=None) -> list:
72
+ return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
73
+
74
+ def update_status(self, prompt):
75
+ print(f"\nextras: {prompt}", file=shared.progress_print_out)
76
+
77
+
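+ # pairs a display name and a model path/URL with the Upscaler implementation that can run it, plus its scale factor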
78
+ class UpscalerData:
79
+ name = None
80
+ data_path = None
81
+ scale: int = 4
82
+ scaler: Upscaler = None
83
+ model = None
84
+
85
+ def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
86
+ self.name = name
87
+ self.data_path = path
88
+ self.scaler = upscaler
89
+ self.scale = scale
90
+ self.model = model
91
+
92
+
93
+ class UpscalerNone(Upscaler):
94
+ name = "None"
95
+ scalers = []
96
+
97
+ def load_model(self, path):
98
+ pass
99
+
100
+ def do_upscale(self, img, selected_model=None):
101
+ return img
102
+
103
+ def __init__(self, dirname=None):
104
+ super().__init__(False)
105
+ self.scalers = [UpscalerData("None", None, self)]
106
+
107
+
108
+ class UpscalerLanczos(Upscaler):
109
+ scalers = []
110
+
111
+ def do_upscale(self, img, selected_model=None):
112
+ return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
113
+
114
+ def load_model(self, _):
115
+ pass
116
+
117
+ def __init__(self, dirname=None):
118
+ super().__init__(False)
119
+ self.name = "Lanczos"
120
+ self.scalers = [UpscalerData("Lanczos", None, self)]
121
+
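For reference, a sketch of how this interface is meant to be extended: subclass Upscaler, implement load_model() and do_upscale(), and register one UpscalerData per selectable model. This assumes modules.upscaler is importable and modules.shared is initialised, as it would be inside the webui; the UpscalerNearest name is illustrative, not part of the commit.

from PIL import Image

from modules.upscaler import Upscaler, UpscalerData


class UpscalerNearest(Upscaler):
    def __init__(self, dirname=None):
        super().__init__(False)
        self.name = "Nearest"
        self.scalers = [UpscalerData("Nearest", None, self)]

    def load_model(self, _):
        pass  # nothing to load for a pure-PIL resampler

    def do_upscale(self, img, selected_model=None):
        # self.scale is set by Upscaler.upscale() before do_upscale() is called
        return img.resize((int(img.width * self.scale), int(img.height * self.scale)),
                          resample=Image.NEAREST)

# usage: UpscalerNearest().upscale(Image.new("RGB", (64, 64)), 2) returns a 128x128 image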