# DiffIR2VR / utils/common.py
# Committed by jimmycv07 in 1de8821 ("first commit").
import functools
import importlib
import os
from typing import Any, Callable, List, Mapping, Tuple
from urllib.parse import urlparse

import numpy as np
import torch
from torch import Tensor
from torch.hub import download_url_to_file, get_dir
from torch.nn import functional as F
def get_obj_from_str(string: str, reload: bool = False) -> Any:
    """Resolve a dotted path such as ``"package.module.Name"`` to an object.

    Args:
        string: Fully qualified dotted path. Everything before the last ``.``
            is imported as a module; the part after it is looked up on that
            module with ``getattr``.
        reload: If True, the module is re-executed with ``importlib.reload``
            before the attribute is fetched.

    Returns:
        The attribute (class, function, constant, ...) found on the module.

    Raises:
        ValueError: if ``string`` contains no ``.`` (nothing to split).
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    target_module = importlib.import_module(module_name, package=None)
    return getattr(target_module, attr_name)
def instantiate_from_config(config: Mapping[str, Any]) -> Any:
    """Build an object from a ``{"target": ..., "params": {...}}`` config mapping.

    ``config["target"]`` is a dotted import path resolved with
    ``get_obj_from_str``. The resulting callable is invoked with the entries of
    ``config["params"]`` as keyword arguments, or with no arguments when the
    ``"params"`` key is absent.

    Args:
        config: Mapping with a required ``"target"`` key and an optional
            ``"params"`` mapping.

    Returns:
        Whatever calling the resolved target returns.

    Raises:
        KeyError: if ``config`` has no ``"target"`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    target = get_obj_from_str(config["target"])
    return target(**config.get("params", dict()))
def wavelet_blur(image: Tensor, radius: int):
    """
    Blur ``image`` with a dilated 3x3 low-pass kernel, channel by channel.

    The kernel is the outer product of ``[1/4, 1/2, 1/4]`` with itself, so its
    weights sum to 1. It is dilated by ``radius``. The input is replicate-padded
    by ``radius`` pixels on every side first, so the spatial size is unchanged.

    Args:
        image: ``(N, C, H, W)`` tensor, or an unbatched ``(C, H, W)`` tensor.
            The channel dimension is read from ``image.shape[-3]``, so any
            channel count is accepted.
        radius: Dilation of the kernel, which is also the padding width.

    Returns:
        Tensor with the same shape, dtype and device as ``image``.
    """
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    channels = image.shape[-3]
    # Depthwise convolution: one copy of the kernel per channel, shape (C, 1, 3, 3).
    kernel = kernel[None, None].repeat(channels, 1, 1, 1)
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # groups=channels filters every channel independently with the same kernel.
    return F.conv2d(image, kernel, groups=channels, dilation=radius)
def wavelet_decomposition(image: Tensor, levels=5):
    """
    Split ``image`` into a high-frequency and a low-frequency band.

    At level ``i`` the current low band is blurred with ``wavelet_blur`` using
    dilation ``2 ** i``. The detail removed by that blur is accumulated into
    the high band. Because the per-level differences telescope,
    ``high_freq + low_freq`` reconstructs ``image`` up to floating-point
    rounding.

    Args:
        image: Tensor accepted by ``wavelet_blur``.
        levels: Number of blur levels. With ``levels <= 0`` no blurring is
            done: the high band is all zeros and the low band is ``image``.

    Returns:
        Tuple ``(high_freq, low_freq)``, both with the shape of ``image``.
    """
    high_freq = torch.zeros_like(image)
    # Bound before the loop so that levels <= 0 returns the input unchanged
    # instead of raising UnboundLocalError.
    low_freq = image
    for level in range(levels):
        blurred = wavelet_blur(low_freq, 2 ** level)
        high_freq += low_freq - blurred
        low_freq = blurred
    return high_freq, low_freq
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
    """
    Give ``content_feat`` the colour (low-frequency band) of ``style_feat``.

    Both tensors are split with ``wavelet_decomposition`` at its default
    number of levels. The result is the high-frequency band of
    ``content_feat`` plus the low-frequency band of ``style_feat``.
    """
    # Keep only the detail band of the content; the discarded tuple element
    # is freed immediately.
    content_detail = wavelet_decomposition(content_feat)[0]
    # Keep only the coarse (colour) band of the style.
    style_coarse = wavelet_decomposition(style_feat)[1]
    return content_detail + style_coarse
# Adapted from https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Return a local path for ``url``, downloading the file first if it is not cached.

    Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL of the file to fetch.
        model_dir (str): Directory to store the file in; it is created if
            missing. If None, ``<torch hub dir>/checkpoints`` is used.
            Default: None.
        progress (bool): Whether to show a download progress bar. Default: True.
        file_name (str): Local file name. If None, the last component of the
            URL path is used. Default: None.

    Returns:
        str: Absolute path of the local file.
    """
    target_dir = model_dir if model_dir is not None else os.path.join(get_dir(), 'checkpoints')
    os.makedirs(target_dir, exist_ok=True)

    url_path = urlparse(url).path
    local_name = file_name if file_name is not None else os.path.basename(url_path)
    cached_file = os.path.abspath(os.path.join(target_dir, local_name))

    if os.path.exists(cached_file):
        return cached_file
    print(f'Downloading: "{url}" to {cached_file}\n')
    download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
def _window_starts(size: int, tile_size: int, tile_stride: int) -> List[int]:
    """Start offsets of tiles along one axis; the last tile ends exactly at ``size``."""
    starts = list(range(0, size - tile_size + 1, tile_stride))
    if (size - tile_size) % tile_stride != 0:
        # The stride does not divide the span evenly: add one tile flush with the border.
        starts.append(size - tile_size)
    return starts


def sliding_windows(h: int, w: int, tile_size: int, tile_stride: int) -> List[Tuple[int, int, int, int]]:
    """Enumerate square tiles covering an ``h x w`` grid.

    Tiles start every ``tile_stride`` pixels along each axis. When the stride
    does not evenly divide ``size - tile_size``, an extra tile aligned to the
    far border is added, so the whole grid is covered.

    Args:
        h: Grid height.
        w: Grid width.
        tile_size: Side length of each square tile.
        tile_stride: Step between consecutive tile starts.

    Returns:
        List of ``(h_start, h_end, w_start, w_end)`` tuples, ordered row-major
        (rows outer, columns inner).

    Raises:
        ValueError: if ``tile_stride`` is not positive, or ``tile_size`` is
            larger than ``h`` or ``w``. A larger tile would otherwise produce
            negative start offsets.
    """
    if tile_stride <= 0:
        raise ValueError(f"tile_stride must be positive, got {tile_stride}")
    if tile_size > h or tile_size > w:
        raise ValueError(f"tile_size {tile_size} exceeds input size ({h}, {w})")
    hi_list = _window_starts(h, tile_size, tile_stride)
    wi_list = _window_starts(w, tile_size, tile_stride)
    return [
        (hi, hi + tile_size, wi, wi + tile_size)
        for hi in hi_list
        for wi in wi_list
    ]
# https://github.com/csslc/CCSR/blob/main/model/q_sampler.py#L503
def gaussian_weights(tile_width: int, tile_height: int, var: float = 0.01) -> np.ndarray:
    """Generate a Gaussian mask of weights for tile contributions.

    Each axis uses a 1-D Gaussian profile centred on the tile. The offset from
    the centre is divided by the axis length before squaring, so ``var`` is a
    variance relative to the tile size. The 2-D mask is the outer product of
    the row and column profiles.

    Args:
        tile_width: Number of columns in the mask.
        tile_height: Number of rows in the mask.
        var: Relative variance of the Gaussian (default 0.01).

    Returns:
        float64 array of shape ``(tile_height, tile_width)``, symmetric about
        its centre along both axes.
    """
    norm = np.sqrt(2 * np.pi * var)

    def _profile(length: int) -> np.ndarray:
        # Centre at (length - 1) / 2 because indices run from 0 to length - 1.
        # Using the same centre on both axes keeps the mask symmetric.
        midpoint = (length - 1) / 2
        offsets = np.arange(length) - midpoint
        return np.exp(-offsets * offsets / (length * length) / (2 * var)) / norm

    return np.outer(_profile(tile_height), _profile(tile_width))
COUNT_VRAM = bool(os.environ.get("COUNT_VRAM", False))
def count_vram_usage(func: Callable) -> Callable:
if not COUNT_VRAM:
return func
def wrapper(*args, **kwargs):
peak_before = torch.cuda.max_memory_allocated() / (1024 ** 3)
ret = func(*args, **kwargs)
torch.cuda.synchronize()
peak_after = torch.cuda.max_memory_allocated() / (1024 ** 3)
print(f"VRAM peak before {func.__name__}: {peak_before:.5f} GB, after: {peak_after:.5f} GB")
return ret
return wrapper