# Script for calling the latent upscaler easily from outside.
# The model definition is included so this file works standalone.

import argparse
import glob
import os
import cv2
from diffusers import AutoencoderKL
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from PIL import Image

# Multiplier applied to the raw VAE latents returned by Upscaler.upscale();
# upscale_images() divides it back out before decoding.
# NOTE(review): presumably the SD v1 VAE scaling factor — confirm if another VAE is used.
_LATENT_SCALE = 0.18215

# Number of residual blocks in Upscaler, and how many blocks each long skip connection wraps.
# The attribute names resblock1..resblockN are part of the checkpoint (state_dict) format.
_NUM_RESBLOCKS = 20
_RESBLOCKS_PER_SKIP = 4


def _initialize_module_weights(root: nn.Module) -> None:
    """Initialize every Conv2d / BatchNorm2d / Linear in ``root.modules()`` (``root`` included).

    Conv2d: Kaiming-normal (mode="fan_out", nonlinearity="relu"), bias 0.
    BatchNorm2d: weight 1, bias 0.
    Linear: normal(0, 0.01), bias 0.
    """
    for m in root.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


class ResidualBlock(nn.Module):
    """conv-BN-ReLU-conv-BN, add the block input, then ReLU.

    The input is added without a projection, so ``out_channels`` must equal
    ``in_channels`` (and the convs must preserve spatial size, as the defaults
    kernel_size=3 / stride=1 / padding=1 do).
    """

    def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
        super(ResidualBlock, self).__init__()

        if out_channels is None:
            out_channels = in_channels

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # TODO: it may be better to apply this ReLU before adding the residual
        self.relu2 = nn.ReLU(inplace=True)

        self._initialize_weights()

    def _initialize_weights(self):
        _initialize_module_weights(self)

    def forward(self, x):
        residual = x
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        out = self.relu2(out)
        return out


class Upscaler(nn.Module):
    """Refine 4-channel latents by predicting a correction that is added to the input.

    conv(4->128) -> 20 residual blocks (a long skip connection around every 4)
    -> conv(128->64) -> conv(64->64) -> zero-initialized 1x1 conv(64->4), whose
    output is added to the input latent.
    """

    def __init__(self):
        super(Upscaler, self).__init__()

        # latent has 4 channels
        self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu1 = nn.ReLU(inplace=True)

        # Residual blocks: brute-force depth (20 blocks), on the reasoning that adding blocks
        # widens the receptive field more than adding channels. Registered as resblock1..resblock20
        # in this order, which keeps the state_dict keys and the init order of the original layout.
        for i in range(1, _NUM_RESBLOCKS + 1):
            setattr(self, f"resblock{i}", ResidualBlock(128))

        # last convs
        self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU(inplace=True)  # defined but deliberately not applied in forward()

        # final conv: output 4 channels
        self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

        self._initialize_weights()

    def _initialize_weights(self):
        _initialize_module_weights(self)
        # Zero-init the final conv ("zero conv"): with its bias also 0, forward() initially
        # returns its input unchanged.
        nn.init.constant_(self.conv_final.weight, 0)

    def _resblocks(self) -> List[nn.Module]:
        """Return resblock1..resblockN in order."""
        return [getattr(self, f"resblock{i}") for i in range(1, _NUM_RESBLOCKS + 1)]

    def forward(self, x):
        inp = x

        x = self.relu1(self.bn1(self.conv1(x)))

        # Add a residual after every few resblocks, intended to improve accuracy and training speed.
        blocks = self._resblocks()
        for start in range(0, len(blocks), _RESBLOCKS_PER_SKIP):
            residual = x
            for block in blocks[start : start + _RESBLOCKS_PER_SKIP]:
                x = block(x)
            x = x + residual

        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))  # no ReLU here on purpose
        x = self.conv_final(x)

        # the network estimates the difference between the input and the output
        return x + inp

    def support_latents(self) -> bool:
        return False

    def upscale(
        self,
        vae: AutoencoderKL,
        lowreso_images: List[Image.Image],
        lowreso_latents: torch.Tensor,
        dtype: torch.dtype,
        width: int,
        height: int,
        batch_size: int = 1,
        vae_batch_size: int = 1,
    ):
        """Lanczos-upsample images, encode them with the VAE, and refine the latents with this model.

        Args:
            vae: VAE used to encode the upsampled images.
            lowreso_images: PIL images to upscale (converted to RGB here).
            lowreso_latents: unused; this model works from images (see support_latents()).
            dtype: dtype the image tensor is cast to before encoding.
            width, height: target pixel size of the Lanczos upsampling.
            batch_size: batch size for this model's forward pass.
            vae_batch_size: batch size for VAE encoding.

        Returns:
            Refined latents multiplied by _LATENT_SCALE, on the VAE's device.
        """
        assert lowreso_images is not None, "Upscaler requires lowreso image"

        # Upsample with Lanczos. convert("RGB") makes RGBA / grayscale / palette inputs 3-channel.
        upsampled = [
            torch.from_numpy(np.array(image.convert("RGB").resize((width, height), Image.LANCZOS))).permute(2, 0, 1).float()
            for image in lowreso_images
        ]
        # The full stack stays on CPU (too large for the GPU); only VAE batches are moved to the device.
        upsampled_images = torch.stack(upsampled, dim=0).to(dtype)
        upsampled_images = upsampled_images / 127.5 - 1.0  # [0, 255] -> [-1, 1]

        # encode the upsampled images to latents in batches
        upsampled_latents = []
        with torch.no_grad():
            for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
                batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
                upsampled_latents.append(vae.encode(batch).latent_dist.sample())
        upsampled_latents = torch.cat(upsampled_latents, dim=0)

        # refine the latents with this model in batches
        print("Upscaling latents...")
        upscaled_latents = []
        with torch.no_grad():
            for i in range(0, upsampled_latents.shape[0], batch_size):
                upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size]))
        upscaled_latents = torch.cat(upscaled_latents, dim=0)

        return upscaled_latents * _LATENT_SCALE


# external interface: returns a model
def create_upscaler(**kwargs):
    """Build an Upscaler and load the weights at ``kwargs["weights"]``.

    ``.safetensors`` files are read with safetensors; anything else goes through torch.load.
    """
    weights = kwargs["weights"]
    model = Upscaler()

    print(f"Loading weights from {weights}...")
    if os.path.splitext(weights)[1].lower() == ".safetensors":
        from safetensors.torch import load_file

        sd = load_file(weights)
    else:
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints
        # (prefer .safetensors for files from untrusted sources).
        sd = torch.load(weights, map_location=torch.device("cpu"))
    model.load_state_dict(sd)
    return model


def _load_images(image_pattern: str) -> List[Tuple[str, Image.Image]]:
    """Load images matching ``image_pattern`` as RGB, cropped from the top-left to multiples of 8.

    Returns (path, image) pairs so each image stays tied to its file name. Images smaller
    than 8 px in either dimension (which would crop to an empty image) are skipped.
    """
    loaded = []
    for image_path in sorted(glob.glob(image_pattern)):
        with Image.open(image_path) as src:
            image = src.convert("RGB")

        # Make the size divisible by 8. Presumably required by the VAE's downsampling — TODO confirm.
        width = image.width - image.width % 8
        height = image.height - image.height % 8
        if width == 0 or height == 0:
            print(f"Skipping {image_path}: image is smaller than 8x8")
            continue
        if (width, height) != image.size:
            image = image.crop((0, 0, width, height))
        loaded.append((image_path, image))
    return loaded


def _decode_to_bgr(vae: AutoencoderKL, latents: torch.Tensor, vae_batch_size: int) -> np.ndarray:
    """Decode (already unscaled) latents in batches and return uint8 NHWC images with channels reversed for cv2."""
    decoded = []
    for i in tqdm(range(0, latents.shape[0], vae_batch_size)):
        with torch.no_grad():
            batch = vae.decode(latents[i : i + vae_batch_size]).sample
        decoded.append(batch.to("cpu"))

    # compute the [-1, 1] -> [0, 255] mapping in fp32 to avoid fp16 rounding before the uint8 cast
    images = torch.cat(decoded, dim=0).permute(0, 2, 3, 1).float().numpy()
    images = (images + 1.0) * 127.5
    images = images.clip(0, 255).astype(np.uint8)
    # RGB -> BGR for cv2.imwrite; make contiguous since the reversed view has a negative stride
    return np.ascontiguousarray(images[..., ::-1])


# another interface: upscale images with a model for given images from command line
def upscale_images(args: argparse.Namespace):
    """Upscale every image matching ``args.image_pattern`` by 2x and write ``*_upscaled`` files.

    Images are grouped by size so each image is upsampled to twice its own size; with
    ``args.debug``, a plain Lanczos 2x version (``*_lanczos4``) is also written for comparison.
    """
    assert args.vae_path is not None, "VAE path is required"
    assert args.weights is not None, "Weights path is required"
    assert args.image_pattern is not None, "Image pattern is required"

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # fp16 on GPU; fp32 on CPU, where half-precision convolutions may be unsupported or slow.
    us_dtype = torch.float16 if DEVICE.type == "cuda" else torch.float32  # TODO: support bf16

    os.makedirs(args.output_dir, exist_ok=True)

    # load images first so an empty match fails fast, before the heavy models are loaded
    items = _load_images(args.image_pattern)
    if not items:
        print(f"No images found for pattern: {args.image_pattern}")
        return

    # load VAE with Diffusers
    print(f"Loading VAE from {args.vae_path}...")
    vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
    vae.to(DEVICE, dtype=us_dtype)

    # prepare model
    print("Preparing model...")
    upscaler: Upscaler = create_upscaler(weights=args.weights)
    upscaler.eval()
    upscaler.to(DEVICE, dtype=us_dtype)

    # debug output: plain Lanczos 2x for comparison
    if args.debug:
        for image_path, image in items:
            image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)
            basename_wo_ext, ext = os.path.splitext(os.path.basename(image_path))
            image_debug.save(os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}"))

    # Group by size: one upscale() call upsamples every image to a single target size,
    # and latents of different sizes cannot be concatenated.
    groups: Dict[Tuple[int, int], List[Tuple[str, Image.Image]]] = {}
    for image_path, image in items:
        groups.setdefault(image.size, []).append((image_path, image))

    for (width, height), group in groups.items():
        print("Upscaling...")
        upscaled_latents = upscaler.upscale(
            vae,
            [image for _, image in group],
            None,
            us_dtype,
            width * 2,
            height * 2,
            batch_size=args.batch_size,
            vae_batch_size=args.vae_batch_size,
        )
        upscaled_latents /= _LATENT_SCALE

        print("Decoding...")
        upscaled_images = _decode_to_bgr(vae, upscaled_latents, args.vae_batch_size)

        for (image_path, _), image in zip(group, upscaled_images):
            basename_wo_ext, ext = os.path.splitext(os.path.basename(image_path))
            dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
            cv2.imwrite(dest_file_name, image)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--vae_path", type=str, default=None, help="VAE path")
    parser.add_argument("--weights", type=str, default=None, help="Weights path")
    parser.add_argument("--image_pattern", type=str, default=None, help="Image pattern")
    parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size")
    parser.add_argument("--debug", action="store_true", help="Debug mode")

    args = parser.parse_args()
    upscale_images(args)