|
''' |
|
# ------------------------------------------------------------------------ |
|
# |
|
# Tiled VAE |
|
# |
|
# Introducing a revolutionary new optimization designed to make |
|
# the VAE work with giant images on limited VRAM! |
|
# Say goodbye to the frustration of OOM and hello to seamless output! |
|
# |
|
# ------------------------------------------------------------------------ |
|
# |
|
# This script is a wild hack that splits the image into tiles, |
|
# encodes each tile separately, and merges the result back together. |
|
# |
|
# Advantages: |
|
# - The VAE can now work with giant images on limited VRAM |
|
# (~10 GB for 8K images!) |
|
# - The merged output is completely seamless without any post-processing. |
|
# |
|
# Drawbacks: |
|
# - NaNs always appear in for 8k images when you use fp16 (half) VAE |
|
# You must use --no-half-vae to disable half VAE for that giant image. |
|
# - The gradient calculation is not compatible with this hack. It |
|
# will break any backward() or torch.autograd.grad() that passes VAE. |
|
# (But you can still use the VAE to generate training data.) |
|
# |
|
# How it works: |
|
# 1. The image is split into tiles, which are then padded with 11/32 pixels' in the decoder/encoder. |
|
# 2. When Fast Mode is disabled: |
|
# 1. The original VAE forward is decomposed into a task queue and a task worker, which starts to process each tile. |
|
# 2. When GroupNorm is needed, it suspends, stores current GroupNorm mean and var, send everything to RAM, and turns to the next tile. |
|
# 3. After all GroupNorm means and vars are summarized, it applies group norm to tiles and continues. |
|
# 4. A zigzag execution order is used to reduce unnecessary data transfer. |
|
# 3. When Fast Mode is enabled: |
|
# 1. The original input is downsampled and passed to a separate task queue. |
|
# 2. Its group norm parameters are recorded and used by all tiles' task queues. |
|
# 3. Each tile is separately processed without any RAM-VRAM data transfer. |
|
# 4. After all tiles are processed, tiles are written to a result buffer and returned. |
|
# Encoder color fix = only estimate GroupNorm before downsampling, i.e., run in a semi-fast mode. |
|
# |
|
# Enjoy! |
|
# |
|
# @Author: LI YI @ Nanyang Technological University - Singapore |
|
# @Date: 2023-03-02 |
|
# @License: CC BY-NC-SA 4.0 |
|
# |
|
# Please give me a star if you like this project! |
|
# |
|
# ------------------------------------------------------------------------- |
|
''' |
|
|
|
import gc |
|
import math |
|
from time import time |
|
from tqdm import tqdm |
|
|
|
import torch |
|
import torch.version |
|
import torch.nn.functional as F |
|
import gradio as gr |
|
|
|
import modules.scripts as scripts |
|
import modules.devices as devices |
|
from modules.shared import state |
|
from modules.ui import gr_show |
|
from modules.processing import opt_f |
|
from modules.sd_vae_approx import cheap_approximation |
|
from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock |
|
|
|
from tile_utils.attn import get_attn_func |
|
from tile_utils.typing import Processing |
|
|
|
|
|
def get_rcmd_enc_tsize(): |
|
if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]: |
|
total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20 |
|
if total_memory > 16*1000: ENCODER_TILE_SIZE = 3072 |
|
elif total_memory > 12*1000: ENCODER_TILE_SIZE = 2048 |
|
elif total_memory > 8*1000: ENCODER_TILE_SIZE = 1536 |
|
else: ENCODER_TILE_SIZE = 960 |
|
else: ENCODER_TILE_SIZE = 512 |
|
return ENCODER_TILE_SIZE |
|
|
|
|
|
def get_rcmd_dec_tsize(): |
|
if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]: |
|
total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20 |
|
if total_memory > 30*1000: DECODER_TILE_SIZE = 256 |
|
elif total_memory > 16*1000: DECODER_TILE_SIZE = 192 |
|
elif total_memory > 12*1000: DECODER_TILE_SIZE = 128 |
|
elif total_memory > 8*1000: DECODER_TILE_SIZE = 96 |
|
else: DECODER_TILE_SIZE = 64 |
|
else: DECODER_TILE_SIZE = 64 |
|
return DECODER_TILE_SIZE |
|
|
|
|
|
def inplace_nonlinearity(x): |
|
|
|
return F.silu(x, inplace=True) |
|
|
|
|
|
def attn2task(task_queue, net): |
|
attn_forward = get_attn_func() |
|
task_queue.append(('store_res', lambda x: x)) |
|
task_queue.append(('pre_norm', net.norm)) |
|
task_queue.append(('attn', lambda x, net=net: attn_forward(net, x))) |
|
task_queue.append(['add_res', None]) |
|
|
|
|
|
def resblock2task(queue, block): |
|
""" |
|
Turn a ResNetBlock into a sequence of tasks and append to the task queue |
|
|
|
@param queue: the target task queue |
|
@param block: ResNetBlock |
|
|
|
""" |
|
if block.in_channels != block.out_channels: |
|
if block.use_conv_shortcut: |
|
queue.append(('store_res', block.conv_shortcut)) |
|
else: |
|
queue.append(('store_res', block.nin_shortcut)) |
|
else: |
|
queue.append(('store_res', lambda x: x)) |
|
queue.append(('pre_norm', block.norm1)) |
|
queue.append(('silu', inplace_nonlinearity)) |
|
queue.append(('conv1', block.conv1)) |
|
queue.append(('pre_norm', block.norm2)) |
|
queue.append(('silu', inplace_nonlinearity)) |
|
queue.append(('conv2', block.conv2)) |
|
queue.append(['add_res', None]) |
|
|
|
|
|
def build_sampling(task_queue, net, is_decoder): |
|
""" |
|
Build the sampling part of a task queue |
|
@param task_queue: the target task queue |
|
@param net: the network |
|
@param is_decoder: currently building decoder or encoder |
|
""" |
|
if is_decoder: |
|
resblock2task(task_queue, net.mid.block_1) |
|
attn2task(task_queue, net.mid.attn_1) |
|
resblock2task(task_queue, net.mid.block_2) |
|
resolution_iter = reversed(range(net.num_resolutions)) |
|
block_ids = net.num_res_blocks + 1 |
|
condition = 0 |
|
module = net.up |
|
func_name = 'upsample' |
|
else: |
|
resolution_iter = range(net.num_resolutions) |
|
block_ids = net.num_res_blocks |
|
condition = net.num_resolutions - 1 |
|
module = net.down |
|
func_name = 'downsample' |
|
|
|
for i_level in resolution_iter: |
|
for i_block in range(block_ids): |
|
resblock2task(task_queue, module[i_level].block[i_block]) |
|
if i_level != condition: |
|
task_queue.append((func_name, getattr(module[i_level], func_name))) |
|
|
|
if not is_decoder: |
|
resblock2task(task_queue, net.mid.block_1) |
|
attn2task(task_queue, net.mid.attn_1) |
|
resblock2task(task_queue, net.mid.block_2) |
|
|
|
|
|
def build_task_queue(net, is_decoder): |
|
""" |
|
Build a single task queue for the encoder or decoder |
|
@param net: the VAE decoder or encoder network |
|
@param is_decoder: currently building decoder or encoder |
|
@return: the task queue |
|
""" |
|
task_queue = [] |
|
task_queue.append(('conv_in', net.conv_in)) |
|
|
|
|
|
|
|
build_sampling(task_queue, net, is_decoder) |
|
|
|
if not is_decoder or not net.give_pre_end: |
|
task_queue.append(('pre_norm', net.norm_out)) |
|
task_queue.append(('silu', inplace_nonlinearity)) |
|
task_queue.append(('conv_out', net.conv_out)) |
|
if is_decoder and net.tanh_out: |
|
task_queue.append(('tanh', torch.tanh)) |
|
|
|
return task_queue |
|
|
|
|
|
def clone_task_queue(task_queue): |
|
""" |
|
Clone a task queue |
|
@param task_queue: the task queue to be cloned |
|
@return: the cloned task queue |
|
""" |
|
return [[item for item in task] for task in task_queue] |
|
|
|
|
|
def get_var_mean(input, num_groups, eps=1e-6): |
|
""" |
|
Get mean and var for group norm |
|
""" |
|
b, c = input.size(0), input.size(1) |
|
channel_in_group = int(c/num_groups) |
|
input_reshaped = input.contiguous().view(1, int(b * num_groups), channel_in_group, *input.size()[2:]) |
|
var, mean = torch.var_mean(input_reshaped, dim=[0, 2, 3, 4], unbiased=False) |
|
return var, mean |
|
|
|
|
|
def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6): |
|
""" |
|
Custom group norm with fixed mean and var |
|
|
|
@param input: input tensor |
|
@param num_groups: number of groups. by default, num_groups = 32 |
|
@param mean: mean, must be pre-calculated by get_var_mean |
|
@param var: var, must be pre-calculated by get_var_mean |
|
@param weight: weight, should be fetched from the original group norm |
|
@param bias: bias, should be fetched from the original group norm |
|
@param eps: epsilon, by default, eps = 1e-6 to match the original group norm |
|
|
|
@return: normalized tensor |
|
""" |
|
b, c = input.size(0), input.size(1) |
|
channel_in_group = int(c/num_groups) |
|
input_reshaped = input.contiguous().view( |
|
1, int(b * num_groups), channel_in_group, *input.size()[2:]) |
|
|
|
out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None, training=False, momentum=0, eps=eps) |
|
out = out.view(b, c, *input.size()[2:]) |
|
|
|
|
|
if weight is not None: |
|
out *= weight.view(1, -1, 1, 1) |
|
if bias is not None: |
|
out += bias.view(1, -1, 1, 1) |
|
return out |
|
|
|
|
|
def crop_valid_region(x, input_bbox, target_bbox, is_decoder): |
|
""" |
|
Crop the valid region from the tile |
|
@param x: input tile |
|
@param input_bbox: original input bounding box |
|
@param target_bbox: output bounding box |
|
@param scale: scale factor |
|
@return: cropped tile |
|
""" |
|
padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox] |
|
margin = [target_bbox[i] - padded_bbox[i] for i in range(4)] |
|
return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]] |
|
|
|
|
|
|
|
|
|
def perfcount(fn): |
|
def wrapper(*args, **kwargs): |
|
ts = time() |
|
|
|
if torch.cuda.is_available(): |
|
torch.cuda.reset_peak_memory_stats(devices.device) |
|
devices.torch_gc() |
|
gc.collect() |
|
|
|
ret = fn(*args, **kwargs) |
|
|
|
devices.torch_gc() |
|
gc.collect() |
|
if torch.cuda.is_available(): |
|
vram = torch.cuda.max_memory_allocated(devices.device) / 2**20 |
|
print(f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB') |
|
else: |
|
print(f'[Tiled VAE]: Done in {time() - ts:.3f}s') |
|
|
|
return ret |
|
return wrapper |
|
|
|
|
|
|
|
|
|
class GroupNormParam: |
|
|
|
def __init__(self): |
|
self.var_list = [] |
|
self.mean_list = [] |
|
self.pixel_list = [] |
|
self.weight = None |
|
self.bias = None |
|
|
|
def add_tile(self, tile, layer): |
|
var, mean = get_var_mean(tile, 32) |
|
|
|
|
|
if var.dtype == torch.float16 and var.isinf().any(): |
|
fp32_tile = tile.float() |
|
var, mean = get_var_mean(fp32_tile, 32) |
|
|
|
|
|
|
|
|
|
self.var_list.append(var) |
|
self.mean_list.append(mean) |
|
self.pixel_list.append( |
|
tile.shape[2]*tile.shape[3]) |
|
if hasattr(layer, 'weight'): |
|
self.weight = layer.weight |
|
self.bias = layer.bias |
|
else: |
|
self.weight = None |
|
self.bias = None |
|
|
|
def summary(self): |
|
""" |
|
summarize the mean and var and return a function |
|
that apply group norm on each tile |
|
""" |
|
if len(self.var_list) == 0: return None |
|
|
|
var = torch.vstack(self.var_list) |
|
mean = torch.vstack(self.mean_list) |
|
max_value = max(self.pixel_list) |
|
pixels = torch.tensor(self.pixel_list, dtype=torch.float32, device=devices.device) / max_value |
|
sum_pixels = torch.sum(pixels) |
|
pixels = pixels.unsqueeze(1) / sum_pixels |
|
var = torch.sum(var * pixels, dim=0) |
|
mean = torch.sum(mean * pixels, dim=0) |
|
return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias) |
|
|
|
@staticmethod |
|
def from_tile(tile, norm): |
|
""" |
|
create a function from a single tile without summary |
|
""" |
|
var, mean = get_var_mean(tile, 32) |
|
if var.dtype == torch.float16 and var.isinf().any(): |
|
fp32_tile = tile.float() |
|
var, mean = get_var_mean(fp32_tile, 32) |
|
|
|
if var.device.type == 'mps': |
|
|
|
var = torch.clamp(var, 0, 60000) |
|
var = var.half() |
|
mean = mean.half() |
|
if hasattr(norm, 'weight'): |
|
weight = norm.weight |
|
bias = norm.bias |
|
else: |
|
weight = None |
|
bias = None |
|
|
|
def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias): |
|
return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6) |
|
return group_norm_func |
|
|
|
|
|
class VAEHook: |
|
|
|
def __init__(self, net, tile_size, is_decoder:bool, fast_decoder:bool, fast_encoder:bool, color_fix:bool, to_gpu:bool=False): |
|
self.net = net |
|
self.tile_size = tile_size |
|
self.is_decoder = is_decoder |
|
self.fast_mode = (fast_encoder and not is_decoder) or (fast_decoder and is_decoder) |
|
self.color_fix = color_fix and not is_decoder |
|
self.to_gpu = to_gpu |
|
self.pad = 11 if is_decoder else 32 |
|
|
|
def __call__(self, x): |
|
original_device = next(self.net.parameters()).device |
|
try: |
|
if self.to_gpu: |
|
self.net = self.net.to(devices.get_optimal_device()) |
|
|
|
B, C, H, W = x.shape |
|
if max(H, W) <= self.pad * 2 + self.tile_size: |
|
print("[Tiled VAE]: the input size is tiny and unnecessary to tile.") |
|
return self.net.original_forward(x) |
|
else: |
|
return self.vae_tile_forward(x) |
|
finally: |
|
self.net = self.net.to(original_device) |
|
|
|
def get_best_tile_size(self, lowerbound, upperbound): |
|
""" |
|
Get the best tile size for GPU memory |
|
""" |
|
divider = 32 |
|
while divider >= 2: |
|
remainer = lowerbound % divider |
|
if remainer == 0: |
|
return lowerbound |
|
candidate = lowerbound - remainer + divider |
|
if candidate <= upperbound: |
|
return candidate |
|
divider //= 2 |
|
return lowerbound |
|
|
|
def split_tiles(self, h, w): |
|
""" |
|
Tool function to split the image into tiles |
|
@param h: height of the image |
|
@param w: width of the image |
|
@return: tile_input_bboxes, tile_output_bboxes |
|
""" |
|
tile_input_bboxes, tile_output_bboxes = [], [] |
|
tile_size = self.tile_size |
|
pad = self.pad |
|
num_height_tiles = math.ceil((h - 2 * pad) / tile_size) |
|
num_width_tiles = math.ceil((w - 2 * pad) / tile_size) |
|
|
|
|
|
num_height_tiles = max(num_height_tiles, 1) |
|
num_width_tiles = max(num_width_tiles, 1) |
|
|
|
|
|
real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles) |
|
real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles) |
|
real_tile_height = self.get_best_tile_size(real_tile_height, tile_size) |
|
real_tile_width = self.get_best_tile_size(real_tile_width, tile_size) |
|
|
|
print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' + |
|
f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}') |
|
|
|
for i in range(num_height_tiles): |
|
for j in range(num_width_tiles): |
|
|
|
|
|
input_bbox = [ |
|
pad + j * real_tile_width, |
|
min(pad + (j + 1) * real_tile_width, w), |
|
pad + i * real_tile_height, |
|
min(pad + (i + 1) * real_tile_height, h), |
|
] |
|
|
|
|
|
output_bbox = [ |
|
input_bbox[0] if input_bbox[0] > pad else 0, |
|
input_bbox[1] if input_bbox[1] < w - pad else w, |
|
input_bbox[2] if input_bbox[2] > pad else 0, |
|
input_bbox[3] if input_bbox[3] < h - pad else h, |
|
] |
|
|
|
|
|
output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox] |
|
tile_output_bboxes.append(output_bbox) |
|
|
|
|
|
tile_input_bboxes.append([ |
|
max(0, input_bbox[0] - pad), |
|
min(w, input_bbox[1] + pad), |
|
max(0, input_bbox[2] - pad), |
|
min(h, input_bbox[3] + pad), |
|
]) |
|
|
|
return tile_input_bboxes, tile_output_bboxes |
|
|
|
@torch.no_grad() |
|
def estimate_group_norm(self, z, task_queue, color_fix): |
|
device = z.device |
|
tile = z |
|
last_id = len(task_queue) - 1 |
|
while last_id >= 0 and task_queue[last_id][0] != 'pre_norm': |
|
last_id -= 1 |
|
if last_id <= 0 or task_queue[last_id][0] != 'pre_norm': |
|
raise ValueError('No group norm found in the task queue') |
|
|
|
for i in range(last_id + 1): |
|
task = task_queue[i] |
|
if task[0] == 'pre_norm': |
|
group_norm_func = GroupNormParam.from_tile(tile, task[1]) |
|
task_queue[i] = ('apply_norm', group_norm_func) |
|
if i == last_id: |
|
return True |
|
tile = group_norm_func(tile) |
|
elif task[0] == 'store_res': |
|
task_id = i + 1 |
|
while task_id < last_id and task_queue[task_id][0] != 'add_res': |
|
task_id += 1 |
|
if task_id >= last_id: |
|
continue |
|
task_queue[task_id][1] = task[1](tile) |
|
elif task[0] == 'add_res': |
|
tile += task[1].to(device) |
|
task[1] = None |
|
elif color_fix and task[0] == 'downsample': |
|
for j in range(i, last_id + 1): |
|
if task_queue[j][0] == 'store_res': |
|
task_queue[j] = ('store_res_cpu', task_queue[j][1]) |
|
return True |
|
else: |
|
tile = task[1](tile) |
|
try: |
|
devices.test_for_nans(tile, "vae") |
|
except: |
|
print(f'Nan detected in fast mode estimation. Fast mode disabled.') |
|
return False |
|
|
|
raise IndexError('Should not reach here') |
|
|
|
@perfcount |
|
@torch.no_grad() |
|
def vae_tile_forward(self, z): |
|
""" |
|
Decode a latent vector z into an image in a tiled manner. |
|
@param z: latent vector |
|
@return: image |
|
""" |
|
device = next(self.net.parameters()).device |
|
net = self.net |
|
tile_size = self.tile_size |
|
is_decoder = self.is_decoder |
|
|
|
z = z.detach() |
|
|
|
N, height, width = z.shape[0], z.shape[2], z.shape[3] |
|
net.last_z_shape = z.shape |
|
|
|
|
|
print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}') |
|
|
|
in_bboxes, out_bboxes = self.split_tiles(height, width) |
|
|
|
|
|
tiles = [] |
|
for input_bbox in in_bboxes: |
|
tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu() |
|
tiles.append(tile) |
|
|
|
num_tiles = len(tiles) |
|
num_completed = 0 |
|
|
|
|
|
single_task_queue = build_task_queue(net, is_decoder) |
|
if self.fast_mode: |
|
|
|
|
|
scale_factor = tile_size / max(height, width) |
|
z = z.to(device) |
|
downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact') |
|
|
|
print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image') |
|
|
|
|
|
|
|
std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True) |
|
std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True) |
|
downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old |
|
del std_old, mean_old, std_new, mean_new |
|
|
|
|
|
downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max()) |
|
estimate_task_queue = clone_task_queue(single_task_queue) |
|
if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix): |
|
single_task_queue = estimate_task_queue |
|
del downsampled_z |
|
|
|
task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)] |
|
|
|
|
|
result = None |
|
result_approx = None |
|
try: |
|
with devices.autocast(): |
|
result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu() |
|
except: pass |
|
|
|
del z |
|
|
|
|
|
pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ") |
|
|
|
|
|
|
|
forward = True |
|
interrupted = False |
|
|
|
while True: |
|
if state.interrupted: interrupted = True ; break |
|
|
|
group_norm_param = GroupNormParam() |
|
for i in range(num_tiles) if forward else reversed(range(num_tiles)): |
|
if state.interrupted: interrupted = True ; break |
|
|
|
tile = tiles[i].to(device) |
|
input_bbox = in_bboxes[i] |
|
task_queue = task_queues[i] |
|
|
|
interrupted = False |
|
while len(task_queue) > 0: |
|
if state.interrupted: interrupted = True ; break |
|
|
|
|
|
|
|
task = task_queue.pop(0) |
|
if task[0] == 'pre_norm': |
|
group_norm_param.add_tile(tile, task[1]) |
|
break |
|
elif task[0] == 'store_res' or task[0] == 'store_res_cpu': |
|
task_id = 0 |
|
res = task[1](tile) |
|
if not self.fast_mode or task[0] == 'store_res_cpu': |
|
res = res.cpu() |
|
while task_queue[task_id][0] != 'add_res': |
|
task_id += 1 |
|
task_queue[task_id][1] = res |
|
elif task[0] == 'add_res': |
|
tile += task[1].to(device) |
|
task[1] = None |
|
else: |
|
tile = task[1](tile) |
|
pbar.update(1) |
|
|
|
if interrupted: break |
|
|
|
|
|
|
|
devices.test_for_nans(tile, "vae") |
|
|
|
if len(task_queue) == 0: |
|
tiles[i] = None |
|
num_completed += 1 |
|
if result is None: |
|
result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False) |
|
result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder) |
|
del tile |
|
elif i == num_tiles - 1 and forward: |
|
forward = False |
|
tiles[i] = tile |
|
elif i == 0 and not forward: |
|
forward = True |
|
tiles[i] = tile |
|
else: |
|
tiles[i] = tile.cpu() |
|
del tile |
|
|
|
if interrupted: break |
|
if num_completed == num_tiles: break |
|
|
|
|
|
group_norm_func = group_norm_param.summary() |
|
if group_norm_func is not None: |
|
for i in range(num_tiles): |
|
task_queue = task_queues[i] |
|
task_queue.insert(0, ('apply_norm', group_norm_func)) |
|
|
|
|
|
pbar.close() |
|
return result if result is not None else result_approx.to(device) |
|
|
|
|
|
class Script(scripts.Script): |
|
|
|
def __init__(self): |
|
self.hooked = False |
|
|
|
def title(self): |
|
return "Tiled VAE" |
|
|
|
def show(self, is_img2img): |
|
return scripts.AlwaysVisible |
|
|
|
def ui(self, is_img2img): |
|
tab = 't2i' if not is_img2img else 'i2i' |
|
uid = lambda name: f'MD-{tab}-{name}' |
|
|
|
with gr.Accordion('Tiled VAE', open=False, elem_id=f'MDV-{tab}'): |
|
with gr.Row() as tab_enable: |
|
enabled = gr.Checkbox(label='Enable Tiled VAE', value=False, elem_id=uid('enable')) |
|
vae_to_gpu = gr.Checkbox(label='Move VAE to GPU (if possible)', value=True, elem_id=uid('vae2gpu')) |
|
|
|
gr.HTML('<p style="margin-bottom:0.8em"> Recommended to set tile sizes as large as possible before got CUDA error: out of memory. </p>') |
|
with gr.Row() as tab_size: |
|
encoder_tile_size = gr.Slider(label='Encoder Tile Size', minimum=256, maximum=4096, step=16, value=get_rcmd_enc_tsize(), elem_id=uid('enc-size')) |
|
decoder_tile_size = gr.Slider(label='Decoder Tile Size', minimum=48, maximum=512, step=16, value=get_rcmd_dec_tsize(), elem_id=uid('dec-size')) |
|
reset = gr.Button(value='↻ Reset', variant='tool') |
|
reset.click(fn=lambda: [get_rcmd_enc_tsize(), get_rcmd_dec_tsize()], outputs=[encoder_tile_size, decoder_tile_size], show_progress=False) |
|
|
|
with gr.Row() as tab_param: |
|
fast_encoder = gr.Checkbox(label='Fast Encoder', value=True, elem_id=uid('fastenc')) |
|
color_fix = gr.Checkbox(label='Fast Encoder Color Fix', value=False, visible=True, elem_id=uid('fastenc-colorfix')) |
|
fast_decoder = gr.Checkbox(label='Fast Decoder', value=True, elem_id=uid('fastdec')) |
|
|
|
fast_encoder.change(fn=gr_show, inputs=fast_encoder, outputs=color_fix, show_progress=False) |
|
|
|
return [ |
|
enabled, |
|
encoder_tile_size, decoder_tile_size, |
|
vae_to_gpu, fast_decoder, fast_encoder, color_fix, |
|
] |
|
|
|
def process(self, p:Processing, |
|
enabled:bool, |
|
encoder_tile_size:int, decoder_tile_size:int, |
|
vae_to_gpu:bool, fast_decoder:bool, fast_encoder:bool, color_fix:bool |
|
): |
|
enabled = True |
|
encoder_tile_size = 1536 |
|
decoder_tile_size = 96 |
|
vae_to_gpu = True |
|
fast_decoder = True |
|
fast_encoder = False |
|
color_fix = False |
|
|
|
vae = p.sd_model.first_stage_model |
|
encoder = vae.encoder |
|
decoder = vae.decoder |
|
|
|
|
|
if not enabled: |
|
if self.hooked: |
|
if isinstance(encoder.forward, VAEHook): |
|
encoder.forward.net = None |
|
encoder.forward = encoder.original_forward |
|
if isinstance(decoder.forward, VAEHook): |
|
decoder.forward.net = None |
|
decoder.forward = decoder.original_forward |
|
self.hooked = False |
|
return |
|
|
|
if devices.get_optimal_device_name().startswith('cuda') and vae.device == devices.cpu and not vae_to_gpu: |
|
print("[Tiled VAE] warn: VAE is not on GPU, check 'Move VAE to GPU' if possible.") |
|
|
|
|
|
kwargs = { |
|
'fast_decoder': fast_decoder, |
|
'fast_encoder': fast_encoder, |
|
'color_fix': color_fix, |
|
'to_gpu': vae_to_gpu, |
|
} |
|
|
|
|
|
if not hasattr(encoder, 'original_forward'): setattr(encoder, 'original_forward', encoder.forward) |
|
if not hasattr(decoder, 'original_forward'): setattr(decoder, 'original_forward', decoder.forward) |
|
|
|
self.hooked = True |
|
|
|
encoder.forward = VAEHook(encoder, encoder_tile_size, is_decoder=False, **kwargs) |
|
decoder.forward = VAEHook(decoder, decoder_tile_size, is_decoder=True, **kwargs) |
|
|
|
def postprocess(self, p:Processing, processed, enabled:bool, *args): |
|
if not enabled: return |
|
|
|
vae = p.sd_model.first_stage_model |
|
encoder = vae.encoder |
|
decoder = vae.decoder |
|
if isinstance(encoder.forward, VAEHook): |
|
encoder.forward.net = None |
|
encoder.forward = encoder.original_forward |
|
if isinstance(decoder.forward, VAEHook): |
|
decoder.forward.net = None |
|
decoder.forward = decoder.original_forward |
|
|