# Source: kohya_ss/tools/latent_upscaler.py (uploaded by Ateras using huggingface_hub, commit fe6327d)
# 外部から簡単にupscalerを呼ぶためのスクリプト
# 単体で動くようにモデル定義も含めている
import argparse
import glob
import os
import cv2
from diffusers import AutoencoderKL
from typing import Dict, List
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from PIL import Image
class ResidualBlock(nn.Module):
    """Two conv + BatchNorm stages wrapped by an identity skip connection.

    Data flow: conv1 -> bn1 -> ReLU -> conv2 -> bn2 -> (+ block input) -> ReLU.

    The skip connection adds the block input unchanged, so the block only works
    when its output has the same shape as its input (out_channels equal to
    in_channels, and a kernel_size/stride/padding combination that preserves
    the spatial size, such as the defaults 3/1/1).

    The attribute names (conv1, bn1, conv2, bn2, ...) become state_dict keys of
    saved checkpoints, so they must not be renamed.
    """

    def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
        super().__init__()
        channels_out = in_channels if out_channels is None else out_channels

        # Keep this creation order: _initialize_weights draws random numbers in
        # module-registration order.
        self.conv1 = nn.Conv2d(in_channels, channels_out, kernel_size, stride, padding, bias=False)
        self.bn1 = nn.BatchNorm2d(channels_out)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size, stride, padding, bias=False)
        self.bn2 = nn.BatchNorm2d(channels_out)
        # Applied after the residual add. The original author noted that applying
        # it before the add might work better.
        self.relu2 = nn.ReLU(inplace=True)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize all submodules in place.

        Conv2d weights use Kaiming-normal (fan_out, relu). BatchNorm2d gets
        weight 1 and bias 0. Linear weights are drawn from N(0, 0.01). Every
        bias that exists is set to 0.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + x)."""
        shortcut = x
        y = self.relu1(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu2(y)
class Upscaler(nn.Module):
    """Latent-space refiner used for 2x image upscaling.

    ``upscale`` resizes the low-resolution images with PIL's Lanczos filter,
    encodes them with the VAE, and runs the resulting latents through this
    network. ``forward`` predicts a correction that is added to its input.

    Submodule names (conv1, bn1, resblock1..resblock20, conv2, bn2, conv3, bn3,
    conv_final) are the state_dict keys of saved weights and must not change.
    Their registration order also fixes the order of random draws during
    initialization.
    """

    # 20 residual blocks in total, with a long skip connection around every 4.
    _NUM_RESBLOCKS = 20
    _RESBLOCKS_PER_GROUP = 4

    def __init__(self):
        super(Upscaler, self).__init__()

        # Stem: the latent has 4 channels; lift it to 128 feature channels.
        self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu1 = nn.ReLU(inplace=True)

        # Brute force by count: adding more blocks should widen the receptive
        # field more than adding more channels would. The blocks are registered
        # as resblock1..resblock20, in that order, to match checkpoint keys.
        for index in range(1, self._NUM_RESBLOCKS + 1):
            setattr(self, f"resblock{index}", ResidualBlock(128))

        # Head convs.
        self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU(inplace=True)  # not used in forward(); kept so the module structure is unchanged

        # Final 1x1 conv back to 4 latent channels.
        self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize all submodules in place.

        Conv2d weights use Kaiming-normal. BatchNorm2d gets weight 1 and bias 0.
        Linear weights are drawn from N(0, 0.01). Every bias that exists is set
        to 0. Finally, conv_final's weight is zeroed (a "zero conv"). Because its
        bias is also zeroed, an untrained model returns its input unchanged.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        nn.init.constant_(self.conv_final.weight, 0)

    def forward(self, x):
        """Refine latents ``x`` (4 channels, as conv1 expects) and return ``x`` plus the predicted correction."""
        inp = x
        x = self.relu1(self.bn1(self.conv1(x)))

        # After every group of 4 resblocks, add the group's input back. This is
        # expected to improve both accuracy and training speed.
        for group_start in range(1, self._NUM_RESBLOCKS + 1, self._RESBLOCKS_PER_GROUP):
            residual = x
            for index in range(group_start, group_start + self._RESBLOCKS_PER_GROUP):
                x = getattr(self, f"resblock{index}")(x)
            x = x + residual

        x = self.relu2(self.bn2(self.conv2(x)))
        # No ReLU after bn3: the author judged it better to leave it out here.
        x = self.bn3(self.conv3(x))
        x = self.conv_final(x)

        # The network estimates the difference between its input and output.
        return x + inp

    def support_latents(self) -> bool:
        """Return False: ``upscale`` needs the decoded low-res images and ignores ``lowreso_latents``."""
        return False

    def upscale(
        self,
        vae: AutoencoderKL,
        lowreso_images: List[Image.Image],
        lowreso_latents: torch.Tensor,
        dtype: torch.dtype,
        width: int,
        height: int,
        batch_size: int = 1,
        vae_batch_size: int = 1,
    ):
        """Lanczos-resize images to (width, height), VAE-encode them, and refine the latents.

        Args:
            vae: VAE used to encode the resized images; its ``.device`` receives each batch.
            lowreso_images: PIL images to upscale (required; any mode, converted to RGB).
            lowreso_latents: unused; this model does not consume latents (see support_latents).
            dtype: dtype the normalized pixel tensor is cast to before encoding
                (presumably the VAE's dtype; TODO confirm with callers).
            width, height: target pixel size of the Lanczos resize.
            batch_size: chunk size for this model's forward passes.
            vae_batch_size: chunk size for VAE encoding.

        Returns:
            Refined latents multiplied by 0.18215, on this model's device.
        """
        assert lowreso_images is not None, "Upscaler requires lowreso image"

        # Resize with PIL Lanczos. Convert to RGB first: RGBA, grayscale or
        # palette images would otherwise yield 4-channel or 2-D arrays that the
        # permute below and the VAE cannot handle.
        upsampled_images = []
        for lowreso_image in lowreso_images:
            if lowreso_image.mode != "RGB":
                lowreso_image = lowreso_image.convert("RGB")
            upsampled_images.append(np.array(lowreso_image.resize((width, height), Image.LANCZOS)))

        # HWC uint8 -> NCHW. Kept on the CPU because the whole stack can be too
        # large to move to CUDA at once.
        upsampled_images = torch.stack(
            [torch.from_numpy(img).permute(2, 0, 1).float() for img in upsampled_images], dim=0
        )
        upsampled_images = upsampled_images.to(dtype)
        upsampled_images = upsampled_images / 127.5 - 1.0  # [0, 255] -> [-1, 1]

        # Encode the resized images in chunks of vae_batch_size.
        upsampled_latents = []
        for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
            batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
            with torch.no_grad():
                batch = vae.encode(batch).latent_dist.sample()
            upsampled_latents.append(batch)
        upsampled_latents = torch.cat(upsampled_latents, dim=0)

        # Refine the latents with this model in chunks of batch_size. Move each
        # chunk to the model's own device/dtype, since the VAE may live on a
        # different device or use a different precision.
        print("Upscaling latents...")
        model_param = next(self.parameters())
        upscaled_latents = []
        for i in range(0, upsampled_latents.shape[0], batch_size):
            batch = upsampled_latents[i : i + batch_size].to(model_param.device, dtype=model_param.dtype)
            with torch.no_grad():
                upscaled_latents.append(self.forward(batch))
        upscaled_latents = torch.cat(upscaled_latents, dim=0)

        # Presumably the Stable Diffusion VAE latent scaling factor.
        return upscaled_latents * 0.18215
# external interface: returns a model
def create_upscaler(**kwargs):
weights = kwargs["weights"]
model = Upscaler()
print(f"Loading weights from {weights}...")
if os.path.splitext(weights)[1] == ".safetensors":
from safetensors.torch import load_file
sd = load_file(weights)
else:
sd = torch.load(weights, map_location=torch.device("cpu"))
model.load_state_dict(sd)
return model
def _load_image_cropped_to_multiple_of_8(image_path):
    """Open ``image_path`` as RGB and crop its right/bottom edges so both sides are multiples of 8."""
    # Multiple of 8, presumably because the VAE works at 1/8 resolution.
    with Image.open(image_path) as opened:
        image = opened.convert("RGB")
    width = image.width - image.width % 8
    height = image.height - image.height % 8
    if (width, height) != image.size:
        image = image.crop((0, 0, width, height))
    return image


def _decode_latents_to_bgr_uint8(vae, latents, vae_batch_size):
    """Decode ``latents`` in chunks and return an (N, H, W, C) uint8 numpy array with channels reversed for cv2."""
    decoded = []
    for i in tqdm(range(0, latents.shape[0], vae_batch_size)):
        with torch.no_grad():
            batch = vae.decode(latents[i : i + vae_batch_size]).sample
        decoded.append(batch.to("cpu"))
    images = torch.cat(decoded, dim=0).permute(0, 2, 3, 1).numpy()  # NCHW -> NHWC
    images = (images + 1.0) * 127.5  # [-1, 1] -> [0, 255]
    images = images.clip(0, 255).astype(np.uint8)
    # RGB -> BGR, the channel order cv2.imwrite expects.
    return np.ascontiguousarray(images[..., ::-1])


# another interface: upscale images with a model for given images from command line
def upscale_images(args: argparse.Namespace):
    """Upscale every image matching ``args.image_pattern`` by 2x and save the results.

    Reads args.vae_path, args.weights, args.image_pattern, args.output_dir,
    args.batch_size, args.vae_batch_size and args.debug. Each input is saved to
    args.output_dir as ``<name>_upscaled<ext>``. With args.debug, a plain PIL
    Lanczos 2x resize is also saved as ``<name>_lanczos4<ext>`` for comparison.
    """
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # fp16 only on CUDA: half-precision convolution on CPU is unsupported or very
    # slow in many PyTorch builds. TODO: support bf16.
    us_dtype = torch.float16 if DEVICE.type == "cuda" else torch.float32

    os.makedirs(args.output_dir, exist_ok=True)

    assert args.vae_path is not None, "VAE path is required"
    assert args.image_pattern is not None, "Image pattern is required"

    # Load images first so we fail fast, before loading any model, when nothing matches.
    image_paths = glob.glob(args.image_pattern)
    if not image_paths:
        print(f"No images found for pattern: {args.image_pattern}")
        return
    images = [_load_image_cropped_to_multiple_of_8(path) for path in image_paths]

    # Load the VAE with Diffusers.
    print(f"Loading VAE from {args.vae_path}...")
    vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
    vae.to(DEVICE, dtype=us_dtype)

    print("Preparing model...")
    upscaler: Upscaler = create_upscaler(weights=args.weights)
    upscaler.eval()
    upscaler.to(DEVICE, dtype=us_dtype)

    # Debug output: plain Lanczos 2x, for comparison.
    if args.debug:
        for image, image_path in zip(images, image_paths):
            image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)
            basename_wo_ext, ext = os.path.splitext(os.path.basename(image_path))
            image_debug.save(os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}"))

    # Upscaler.upscale resizes every image in one call to a single (width, height),
    # so upscale each group of same-sized images separately. Otherwise
    # differently-sized inputs would all be forced to one size.
    indices_by_size: Dict[tuple, List[int]] = {}
    for index, image in enumerate(images):
        indices_by_size.setdefault(image.size, []).append(index)

    for (width, height), indices in indices_by_size.items():
        print("Upscaling...")
        upscaled_latents = upscaler.upscale(
            vae,
            [images[i] for i in indices],
            None,
            us_dtype,
            width * 2,
            height * 2,
            batch_size=args.batch_size,
            vae_batch_size=args.vae_batch_size,
        )
        upscaled_latents /= 0.18215  # undo the scaling applied at the end of Upscaler.upscale

        print("Decoding...")
        upscaled_images = _decode_latents_to_bgr_uint8(vae, upscaled_latents, args.vae_batch_size)

        for index, image in zip(indices, upscaled_images):
            basename_wo_ext, ext = os.path.splitext(os.path.basename(image_paths[index]))
            cv2.imwrite(os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}"), image)
if __name__ == "__main__":
    # Command-line entry point. --vae_path, --weights and --image_pattern are
    # required: without them upscale_images cannot run, and argparse reports the
    # omission up front instead of failing after a model load.
    parser = argparse.ArgumentParser()
    parser.add_argument("--vae_path", type=str, required=True, help="VAE path")
    parser.add_argument("--weights", type=str, required=True, help="Weights path")
    parser.add_argument("--image_pattern", type=str, required=True, help="Image pattern")
    parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size")
    parser.add_argument("--debug", action="store_true", help="Debug mode")

    args = parser.parse_args()
    upscale_images(args)