# AuraSR: GAN-based super-resolution for real-world images, a reproduction of
# the GigaGAN* paper. The implementation is based on the unofficial
# lucidrains/gigagan-pytorch repository and heavily modified from there.
#
# https://mingukkang.github.io/GigaGAN/
from math import log2, ceil
from functools import partial
from typing import Any, Optional, List, Iterable

import torch
from torchvision import transforms
from PIL import Image
from torch import nn, einsum, Tensor
import torch.nn.functional as F

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange


def get_same_padding(size, kernel, dilation, stride):
    """Return the integer padding used by ``AdaptiveConv2DMod``.

    Computes ``((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2``.
    """
    return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2


class AdaptiveConv2DMod(nn.Module):
    """Modulated 2D convolution with an optional bank of selectable kernels.

    The per-sample modulation / demodulation follows the scheme the original
    comments attribute to StyleGAN2.  When ``num_conv_kernels > 1`` the layer
    is "adaptive": a softmax over ``kernel_mod`` mixes the kernels of the bank
    into one kernel per sample.
    """

    def __init__(
        self,
        dim,
        dim_out,
        kernel,
        *,
        demod=True,
        stride=1,
        dilation=1,
        eps=1e-8,
        num_conv_kernels=1,  # set this to be greater than 1 for adaptive
    ):
        super().__init__()
        self.eps = eps

        self.dim_out = dim_out

        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.adaptive = num_conv_kernels > 1

        self.weights = nn.Parameter(
            torch.randn((num_conv_kernels, dim_out, dim, kernel, kernel))
        )

        self.demod = demod

        nn.init.kaiming_normal_(
            self.weights, a=0, mode="fan_in", nonlinearity="leaky_relu"
        )

    def forward(
        self, fmap, mod: Optional[Tensor] = None, kernel_mod: Optional[Tensor] = None
    ):
        """Apply the modulated convolution to ``fmap``.

        notation

        b - batch
        n - convs
        o - output
        i - input
        k - kernel

        Args:
            fmap: feature map, rearranged below as ``b c h w``.
            mod: per-sample channel modulation, rearranged below as ``b i``.
                Required even though the annotation allows ``None``.
            kernel_mod: per-sample kernel logits (``b n``); required when the
                layer is adaptive.

        Raises:
            ValueError: if ``mod`` is ``None``.
        """
        if not exists(mod):
            # The original code dereferenced `mod.shape` unconditionally and
            # failed with an opaque AttributeError; fail with a clear message.
            raise ValueError("AdaptiveConv2DMod.forward requires a `mod` tensor")

        b, h = fmap.shape[0], fmap.shape[-2]

        # account for feature map that has been expanded by the scale in the first dimension
        # due to multiscale inputs and outputs
        if mod.shape[0] != b:
            mod = repeat(mod, "b ... -> (s b) ...", s=b // mod.shape[0])

        if exists(kernel_mod):
            kernel_mod_has_el = kernel_mod.numel() > 0

            assert self.adaptive or not kernel_mod_has_el

            if kernel_mod_has_el and kernel_mod.shape[0] != b:
                kernel_mod = repeat(
                    kernel_mod, "b ... -> (s b) ...", s=b // kernel_mod.shape[0]
                )

        # prepare weights for modulation
        weights = self.weights

        if self.adaptive:
            weights = repeat(weights, "... -> b ...", b=b)

            # determine an adaptive weight and 'select' the kernel to use with softmax
            assert exists(kernel_mod) and kernel_mod.numel() > 0

            kernel_attn = kernel_mod.softmax(dim=-1)
            kernel_attn = rearrange(kernel_attn, "b n -> b n 1 1 1 1")

            weights = reduce(weights * kernel_attn, "b n ... -> b ...", "sum")

        # do the modulation, demodulation, as done in stylegan2
        mod = rearrange(mod, "b i -> b 1 i 1 1")

        weights = weights * (mod + 1)

        if self.demod:
            inv_norm = (
                reduce(weights**2, "b o i k1 k2 -> b o 1 1 1", "sum")
                .clamp(min=self.eps)
                .rsqrt()
            )
            weights = weights * inv_norm

        # Fold the batch into the channel axis so one grouped convolution
        # applies a different kernel to every sample.
        fmap = rearrange(fmap, "b c h w -> 1 (b c) h w")
        weights = rearrange(weights, "b o ... -> (b o) ...")

        padding = get_same_padding(h, self.kernel, self.dilation, self.stride)
        fmap = F.conv2d(fmap, weights, padding=padding, groups=b)

        return rearrange(fmap, "1 (b o) ... -> b o ...", b=b)


class Attend(nn.Module):
    """Dot-product attention with an optional fused (``flash``) code path."""

    def __init__(self, dropout=0.0, flash=False):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)
        # Not read by forward(); kept because it is a registered parameter and
        # therefore part of the state dict.
        self.scale = nn.Parameter(torch.randn(1))
        self.flash = flash

    def flash_attn(self, q, k, v):
        """Run ``F.scaled_dot_product_attention`` on contiguous q, k, v."""
        q, k, v = map(lambda t: t.contiguous(), (q, k, v))
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout if self.training else 0.0
        )
        return out

    def forward(self, q, k, v):
        """Attend ``q`` over ``k`` / ``v`` (einsum layout ``b h n d``)."""
        if self.flash:
            return self.flash_attn(q, k, v)

        scale = q.shape[-1] ** -0.5

        # similarity
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out


def exists(x):
    """Return True when ``x`` is not ``None``."""
    return x is not None


def default(val, d):
    """Return ``val`` if it is not ``None``; otherwise ``d`` (called if callable)."""
    if exists(val):
        return val
    return d() if callable(d) else d


def cast_tuple(t, length=1):
    """Return ``t`` unchanged if it is a tuple, else ``t`` repeated ``length`` times."""
    if isinstance(t, tuple):
        return t
    return (t,) * length


def identity(t, *args, **kwargs):
    """Return the first argument, ignoring every other argument."""
    return t


def is_power_of_two(n):
    """Return True when ``n`` is a positive power of two.

    Integers are tested exactly with a bit trick: ``log2`` rounds large
    integers (``2**53 + 1`` would be reported as a power of two) and raises on
    ``0``.  Non-integer numbers keep the ``log2`` based test.
    """
    if isinstance(n, int):
        return n > 0 and n & (n - 1) == 0
    return n > 0 and log2(n).is_integer()


def null_iterator():
    """Yield ``None`` forever."""
    while True:
        yield None


def Downsample(dim, dim_out=None):
    """Build a 2x downsampler: space-to-depth rearrange followed by a 1x1 conv."""
    return nn.Sequential(
        Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
        nn.Conv2d(dim * 4, default(dim_out, dim), 1),
    )


class RMSNorm(nn.Module):
    """Normalize over the channel axis, then apply a learned per-channel gain."""

    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.eps = 1e-4  # stored but not used by forward()

    def forward(self, x):
        return F.normalize(x, dim=1) * self.g * (x.shape[1] ** 0.5)


# building block modules


class Block(nn.Module):
    """A 3x3 ``AdaptiveConv2DMod`` followed by SiLU."""

    def __init__(self, dim, dim_out, groups=8, num_conv_kernels=0):
        super().__init__()
        # `groups` is accepted but not used by this block.
        self.proj = AdaptiveConv2DMod(
            dim, dim_out, kernel=3, num_conv_kernels=num_conv_kernels
        )
        self.kernel = 3
        self.dilation = 1
        self.stride = 1

        self.act = nn.SiLU()

    def forward(self, x, conv_mods_iter: Optional[Iterable] = None):
        """Consume two entries of ``conv_mods_iter`` (mod, then kernel_mod)."""
        conv_mods_iter = default(conv_mods_iter, null_iterator())

        x = self.proj(x, mod=next(conv_mods_iter), kernel_mod=next(conv_mods_iter))

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    """Two ``Block`` layers plus a residual connection (1x1 conv if dims differ).

    If ``style_dims`` is given, the four modulation widths this block consumes
    (``dim``, ``num_conv_kernels``, ``dim_out``, ``num_conv_kernels``) are
    appended to it in place so the caller can size the style projection.
    """

    def __init__(
        self,
        dim,
        dim_out,
        *,
        groups=8,
        num_conv_kernels=0,
        style_dims: Optional[List] = None,
    ):
        super().__init__()
        # `None` replaces the former mutable `[]` default, which was shared
        # between every call that omitted the argument.
        if style_dims is not None:
            style_dims.extend([dim, num_conv_kernels, dim_out, num_conv_kernels])

        self.block1 = Block(
            dim, dim_out, groups=groups, num_conv_kernels=num_conv_kernels
        )
        self.block2 = Block(
            dim_out, dim_out, groups=groups, num_conv_kernels=num_conv_kernels
        )
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, conv_mods_iter: Optional[Iterable] = None):
        h = self.block1(x, conv_mods_iter=conv_mods_iter)
        h = self.block2(h, conv_mods_iter=conv_mods_iter)

        return h + self.res_conv(x)


class LinearAttention(nn.Module):
    """Attention that contracts keys with values before applying the queries."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), RMSNorm(dim))

    def forward(self, x):
        b, c, h, w = x.shape

        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale

        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)


class Attention(nn.Module):
    """Full self-attention over all spatial positions of a feature map."""

    def __init__(self, dim, heads=4, dim_head=32, flash=False):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = RMSNorm(dim)

        self.attend = Attend(flash=flash)
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=1)

        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h (x y) c", h=self.heads), qkv
        )

        out = self.attend(q, k, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)

        return self.to_out(out)


# feedforward


def FeedForward(dim, mult=4):
    """Build RMSNorm -> 1x1 conv (expand by ``mult``) -> GELU -> 1x1 conv."""
    return nn.Sequential(
        RMSNorm(dim),
        nn.Conv2d(dim, dim * mult, 1),
        nn.GELU(),
        nn.Conv2d(dim * mult, dim, 1),
    )


# transformers


class Transformer(nn.Module):
    """``depth`` residual (Attention, FeedForward) pairs."""

    def __init__(self, dim, dim_head=64, heads=8, depth=1, flash_attn=True, ff_mult=4):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Attention(
                            dim=dim, dim_head=dim_head, heads=heads, flash=flash_attn
                        ),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return x


class LinearTransformer(nn.Module):
    """``depth`` residual (LinearAttention, FeedForward) pairs."""

    def __init__(self, dim, dim_head=64, heads=8, depth=1, ff_mult=4):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        LinearAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return x


class NearestNeighborhoodUpsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution."""

    def __init__(self, dim, dim_out=None):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.conv = nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Batches of 64 or more are made contiguous before interpolation
        # (threshold taken from the original code; rationale not stated there).
        if x.shape[0] >= 64:
            x = x.contiguous()

        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)

        return x


class EqualLinear(nn.Module):
    """Linear layer whose weight and bias are scaled by ``lr_mul`` at call time."""

    def __init__(self, dim, dim_out, lr_mul=1, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_out, dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(dim_out))
        else:
            # Always define the attribute so forward() works with bias=False;
            # a None parameter adds no entry to the state dict.
            self.register_parameter("bias", None)

        self.lr_mul = lr_mul

    def forward(self, input):
        bias = self.bias * self.lr_mul if exists(self.bias) else None
        return F.linear(input, self.weight * self.lr_mul, bias=bias)


class StyleGanNetwork(nn.Module):
    """MLP of ``depth`` (EqualLinear, LeakyReLU) pairs mapping noise to styles."""

    def __init__(self, dim_in=128, dim_out=512, depth=8, lr_mul=0.1, dim_text_latent=0):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_text_latent = dim_text_latent

        layers = []
        for i in range(depth):
            # Only the first layer sees the (optionally text-augmented) input.
            dim_in_layer = dim_in + dim_text_latent if i == 0 else dim_out
            layers.extend(
                [EqualLinear(dim_in_layer, dim_out, lr_mul), nn.LeakyReLU(0.2)]
            )

        self.net = nn.Sequential(*layers)

    def forward(self, x, text_latent=None):
        """Normalize ``x``, optionally concatenate ``text_latent``, run the MLP."""
        x = F.normalize(x, dim=1)
        if self.dim_text_latent > 0:
            assert exists(text_latent)
            x = torch.cat((x, text_latent), dim=-1)
        return self.net(x)


class UnetUpsampler(torch.nn.Module):
    """U-Net style generator driven by style-modulated convolutions.

    ``style_network`` is a dict of ``StyleGanNetwork`` keyword arguments and is
    required despite its ``None`` default.  Every ``ResnetBlock`` records the
    modulation widths it needs in a shared list; ``style_to_conv_modulations``
    projects the style vector to their total width.
    """

    def __init__(
        self,
        dim: int,
        *,
        image_size: int,
        input_image_size: int,
        init_dim: Optional[int] = None,
        out_dim: Optional[int] = None,
        style_network: Optional[dict] = None,
        up_dim_mults: tuple = (1, 2, 4, 8, 16),
        down_dim_mults: tuple = (4, 8, 16),
        channels: int = 3,
        resnet_block_groups: int = 8,
        full_attn: tuple = (False, False, False, True, True),
        flash_attn: bool = True,
        self_attn_dim_head: int = 64,
        self_attn_heads: int = 8,
        attn_depths: tuple = (2, 2, 2, 2, 4),
        mid_attn_depth: int = 4,
        num_conv_kernels: int = 4,
        resize_mode: str = "bilinear",
        unconditional: bool = True,
        skip_connect_scale: Optional[float] = None,
    ):
        super().__init__()

        if not exists(style_network):
            # `**None` used to fail with an unhelpful TypeError; keep the
            # exception type but say what is missing.
            raise TypeError(
                "style_network must be a dict of StyleGanNetwork keyword "
                "arguments, got None"
            )

        self.style_network = style_network = StyleGanNetwork(**style_network)
        self.unconditional = unconditional
        assert not (
            unconditional
            and exists(style_network)
            and style_network.dim_text_latent > 0
        )

        assert is_power_of_two(image_size) and is_power_of_two(
            input_image_size
        ), "both output image size and input image size must be power of 2"
        assert (
            input_image_size < image_size
        ), "input image size must be smaller than the output image size, thus upsampling"

        self.image_size = image_size
        self.input_image_size = input_image_size

        # Filled in place by every ResnetBlock built through `block_klass`;
        # construction order must match the consumption order in forward().
        style_embed_split_dims = []

        self.channels = channels
        input_channels = channels

        init_dim = default(init_dim, dim)

        up_dims = [init_dim, *map(lambda m: dim * m, up_dim_mults)]
        init_down_dim = up_dims[len(up_dim_mults) - len(down_dim_mults)]
        down_dims = [init_down_dim, *map(lambda m: dim * m, down_dim_mults)]
        self.init_conv = nn.Conv2d(input_channels, init_down_dim, 7, padding=3)

        up_in_out = list(zip(up_dims[:-1], up_dims[1:]))
        down_in_out = list(zip(down_dims[:-1], down_dims[1:]))

        block_klass = partial(
            ResnetBlock,
            groups=resnet_block_groups,
            num_conv_kernels=num_conv_kernels,
            style_dims=style_embed_split_dims,
        )

        FullAttention = partial(Transformer, flash_attn=flash_attn)
        *_, mid_dim = up_dims

        self.skip_connect_scale = default(skip_connect_scale, 2**-0.5)

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        block_count = 6

        for (dim_in, dim_out), layer_full_attn, layer_attn_depth in zip(
            down_in_out, full_attn, attn_depths
        ):
            attn_klass = FullAttention if layer_full_attn else LinearTransformer

            blocks = [block_klass(dim_in, dim_in) for _ in range(block_count)]

            self.downs.append(
                nn.ModuleList(
                    [
                        nn.ModuleList(blocks),
                        nn.ModuleList(
                            [
                                # Stages without full attention get no
                                # attention module at all.
                                (
                                    attn_klass(
                                        dim_in,
                                        dim_head=self_attn_dim_head,
                                        heads=self_attn_heads,
                                        depth=layer_attn_depth,
                                    )
                                    if layer_full_attn
                                    else None
                                ),
                                nn.Conv2d(
                                    dim_in, dim_out, kernel_size=3, stride=2, padding=1
                                ),
                            ]
                        ),
                    ]
                )
            )

        self.mid_block1 = block_klass(mid_dim, mid_dim)
        self.mid_attn = FullAttention(
            mid_dim,
            dim_head=self_attn_dim_head,
            heads=self_attn_heads,
            depth=mid_attn_depth,
        )
        self.mid_block2 = block_klass(mid_dim, mid_dim)

        *_, last_dim = up_dims

        for ind, (
            (dim_in, dim_out),
            layer_full_attn,
            layer_attn_depth,
        ) in enumerate(
            zip(
                reversed(up_in_out),
                reversed(full_attn),
                reversed(attn_depths),
            )
        ):
            attn_klass = FullAttention if layer_full_attn else LinearTransformer

            # Stages that have a matching down stage take doubled input
            # channels (forward() concatenates a skip tensor before each block).
            input_dim = dim_in * 2 if ind < len(down_in_out) else dim_in
            blocks = [block_klass(input_dim, dim_in) for _ in range(block_count)]

            self.ups.append(
                nn.ModuleList(
                    [
                        nn.ModuleList(blocks),
                        nn.ModuleList(
                            [
                                NearestNeighborhoodUpsample(
                                    last_dim if ind == 0 else dim_out,
                                    dim_in,
                                ),
                                (
                                    attn_klass(
                                        dim_in,
                                        dim_head=self_attn_dim_head,
                                        heads=self_attn_heads,
                                        depth=layer_attn_depth,
                                    )
                                    if layer_full_attn
                                    else None
                                ),
                            ]
                        ),
                    ]
                )
            )

        self.out_dim = default(out_dim, channels)
        self.final_res_block = block_klass(dim, dim)
        self.final_to_rgb = nn.Conv2d(dim, channels, 1)
        self.resize_mode = resize_mode
        self.style_to_conv_modulations = nn.Linear(
            style_network.dim_out, sum(style_embed_split_dims)
        )
        self.style_embed_split_dims = style_embed_split_dims

    @property
    def allowable_rgb_resolutions(self):
        """Powers of two from ``input_image_size`` up to, but excluding, ``image_size``."""
        input_res_base = int(log2(self.input_image_size))
        output_res_base = int(log2(self.image_size))
        allowed_rgb_res_base = list(range(input_res_base, output_res_base))
        return [*map(lambda p: 2**p, allowed_rgb_res_base)]

    @property
    def device(self):
        """Device of the first parameter of the module."""
        return next(self.parameters()).device

    @property
    def total_params(self):
        """Total number of elements across all parameters."""
        return sum([p.numel() for p in self.parameters()])

    def resize_image_to(self, x, size):
        """Resize ``x`` to ``(size, size)`` using ``self.resize_mode``."""
        return F.interpolate(x, (size, size), mode=self.resize_mode)

    def forward(
        self,
        lowres_image: torch.Tensor,
        styles: Optional[torch.Tensor] = None,
        noise: Optional[torch.Tensor] = None,
        global_text_tokens: Optional[torch.Tensor] = None,
        return_all_rgbs: bool = False,
    ):
        """Upsample ``lowres_image``.

        Args:
            lowres_image: batch whose last two dims must both equal
                ``input_image_size``; it is noise-augmented and clamped to
                ``[0, 1]`` before use.
            styles: precomputed style vectors; when omitted they are produced
                by the style network from ``noise``.
            noise: style-network input; random when omitted.
            global_text_tokens: forwarded to the style network as its text
                latent.
            return_all_rgbs: when True, return ``(rgb, [])`` instead of ``rgb``.
        """
        x = lowres_image

        noise_scale = 0.001  # Adjust the scale of the noise as needed
        noise_aug = torch.randn_like(x) * noise_scale
        x = x + noise_aug
        x = x.clamp(0, 1)

        shape = x.shape
        batch_size = shape[0]

        assert shape[-2:] == ((self.input_image_size,) * 2)

        # styles
        if not exists(styles):
            assert exists(self.style_network)

            noise = default(
                noise,
                torch.randn(
                    (batch_size, self.style_network.dim_in), device=self.device
                ),
            )

            styles = self.style_network(noise, global_text_tokens)

        # project styles to conv modulations
        conv_mods = self.style_to_conv_modulations(styles)
        conv_mods = conv_mods.split(self.style_embed_split_dims, dim=-1)
        # A single shared iterator: every block pulls its own modulations in
        # the same order the blocks were constructed.
        conv_mods = iter(conv_mods)

        x = self.init_conv(x)

        h = []
        for blocks, (attn, downsample) in self.downs:
            for block in blocks:
                x = block(x, conv_mods_iter=conv_mods)
                h.append(x)

            if attn is not None:
                x = attn(x)

            x = downsample(x)

        x = self.mid_block1(x, conv_mods_iter=conv_mods)
        x = self.mid_attn(x)
        x = self.mid_block2(x, conv_mods_iter=conv_mods)

        for blocks, (upsample, attn) in self.ups:
            x = upsample(x)
            for block in blocks:
                if h:
                    # Skip connections are consumed last-in, first-out and
                    # scaled before concatenation.
                    res = h.pop()
                    res = res * self.skip_connect_scale
                    x = torch.cat((x, res), dim=1)

                x = block(x, conv_mods_iter=conv_mods)

            if attn is not None:
                x = attn(x)

        x = self.final_res_block(x, conv_mods_iter=conv_mods)
        rgb = self.final_to_rgb(x)

        if not return_all_rgbs:
            return rgb

        return rgb, []


def tile_image(image, chunk_size=64):
    """Split a ``(c, h, w)`` tensor into row-major tiles of ``chunk_size``.

    Edge tiles are smaller when ``h`` or ``w`` is not a multiple of
    ``chunk_size``.

    Returns:
        ``(tiles, h_chunks, w_chunks)``.
    """
    c, h, w = image.shape
    h_chunks = ceil(h / chunk_size)
    w_chunks = ceil(w / chunk_size)
    tiles = [
        image[
            :,
            i * chunk_size:(i + 1) * chunk_size,
            j * chunk_size:(j + 1) * chunk_size,
        ]
        for i in range(h_chunks)
        for j in range(w_chunks)
    ]
    return tiles, h_chunks, w_chunks


def merge_tiles(tiles, h_chunks, w_chunks, chunk_size=64):
    """Reassemble row-major ``tiles`` into one ``(c, h, w)`` tensor.

    The output canvas is ``h_chunks * chunk_size`` by ``w_chunks * chunk_size``
    and zero-filled wherever a tile is smaller than ``chunk_size``.  It is
    allocated with the dtype and device of the first tile.
    """
    # Determine the shape of the output tensor
    c = tiles[0].shape[0]
    h = h_chunks * chunk_size
    w = w_chunks * chunk_size

    # Create an empty tensor to hold the merged image
    merged = torch.zeros((c, h, w), dtype=tiles[0].dtype, device=tiles[0].device)

    # Iterate over the tiles and place them in the correct position
    for idx, tile in enumerate(tiles):
        i, j = divmod(idx, w_chunks)

        h_start = i * chunk_size
        w_start = j * chunk_size

        tile_h, tile_w = tile.shape[1:]
        merged[:, h_start:h_start + tile_h, w_start:w_start + tile_w] = tile

    return merged


class AuraSR:
    """Convenience wrapper: builds a ``UnetUpsampler`` and upscales PIL images."""

    def __init__(self, config: dict[str, Any], device: str = "cuda"):
        self.upsampler = UnetUpsampler(**config).to(device)
        self.input_image_size = config["input_image_size"]

    @classmethod
    def from_pretrained(
        cls,
        model_id: str = "fal-ai/AuraSR",
        device: str = "cuda",
        use_safetensors: bool = True,
    ):
        """Build a model and load weights from a local file or the Hugging Face hub.

        Args:
            model_id: path to a local ``.safetensors`` / ``.ckpt`` file (with a
                ``config.json`` next to it), or a hub model id.
            device: device the model is moved to.
            use_safetensors: for hub models, pick ``model.safetensors`` over
                ``model.ckpt``; for local files the suffix decides.

        Raises:
            ValueError: local file with an unsupported suffix.
            FileNotFoundError: local file without a sibling ``config.json``.
            ImportError: safetensors requested but not installed.
        """
        import json
        from pathlib import Path

        local_file = Path(model_id)

        # Check if model_id is a local file
        if local_file.is_file():
            if local_file.suffix == '.safetensors':
                use_safetensors = True
            elif local_file.suffix == '.ckpt':
                use_safetensors = False
            else:
                raise ValueError(f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files.")

            # For local files, we need to provide the config separately
            config_path = local_file.with_name('config.json')
            if not config_path.exists():
                raise FileNotFoundError(
                    f"Config file not found: {config_path}. "
                    f"When loading from a local file, ensure that 'config.json' "
                    f"is present in the same directory as '{local_file.name}'. "
                    f"If you're trying to load a model from Hugging Face, "
                    f"please provide the model ID instead of a file path."
                )

            config = json.loads(config_path.read_text())
            weights_path = local_file
        else:
            # Imported only here so local checkpoints load without
            # huggingface_hub installed.
            from huggingface_hub import snapshot_download

            hf_model_path = Path(snapshot_download(model_id))
            config = json.loads((hf_model_path / "config.json").read_text())
            weights_path = hf_model_path / (
                "model.safetensors" if use_safetensors else "model.ckpt"
            )

        model = cls(config, device)

        if use_safetensors:
            # Only the import is guarded, so a failure while reading the file
            # is not misreported as a missing dependency.
            try:
                from safetensors.torch import load_file
            except ImportError as err:
                raise ImportError(
                    "The safetensors library is not installed. "
                    "Please install it with `pip install safetensors` "
                    "or use `use_safetensors=False` to load the model with PyTorch."
                ) from err
            checkpoint = load_file(weights_path)
        else:
            # NOTE(security): torch.load unpickles the file; only load .ckpt
            # files from trusted sources.  map_location="cpu" lets checkpoints
            # saved on CUDA load on CPU-only hosts; load_state_dict then copies
            # the values onto the model's device.
            checkpoint = torch.load(weights_path, map_location="cpu")

        model.upsampler.load_state_dict(checkpoint, strict=True)
        return model

    @torch.no_grad()
    def upscale_4x(self, image: Image.Image, max_batch_size=8) -> Image.Image:
        """Upscale ``image`` tile by tile and return the result as a PIL image.

        The image is padded on the right / bottom to a multiple of
        ``input_image_size``, cut into tiles, run through the upsampler in
        batches of at most ``max_batch_size`` tiles, merged, and cropped to
        four times the original height and width.
        """
        scale = 4
        tile_size = self.input_image_size
        device = self.upsampler.device

        image_tensor = transforms.ToTensor()(image).unsqueeze(0)
        _, _, h, w = image_tensor.shape
        pad_h = (tile_size - h % tile_size) % tile_size
        pad_w = (tile_size - w % tile_size) % tile_size

        # Pad the image.  Reflect padding requires the pad to be smaller than
        # the padded dimension, so fall back to edge replication for images
        # that are small relative to the tile size.
        pad_mode = 'reflect' if pad_h < h and pad_w < w else 'replicate'
        image_tensor = torch.nn.functional.pad(
            image_tensor, (0, pad_w, 0, pad_h), mode=pad_mode
        ).squeeze(0)
        tiles, h_chunks, w_chunks = tile_image(image_tensor, tile_size)

        # Noise width follows the configured style network instead of a
        # hard-coded 128 (the StyleGanNetwork default).
        noise_dim = self.upsampler.style_network.dim_in

        # Batch processing of tiles
        reconstructed_tiles = []
        for start in range(0, len(tiles), max_batch_size):
            model_input = torch.stack(tiles[start:start + max_batch_size]).to(device)
            generator_output = self.upsampler(
                lowres_image=model_input,
                noise=torch.randn(model_input.shape[0], noise_dim, device=device),
            )
            reconstructed_tiles.extend(generator_output.clamp_(0, 1).detach().cpu())

        merged_tensor = merge_tiles(
            reconstructed_tiles, h_chunks, w_chunks, tile_size * scale
        )
        unpadded = merged_tensor[:, :h * scale, :w * scale]

        to_pil = transforms.ToPILImage()
        return to_pil(unpadded)