New logic
- app.py +72 -118
- model/__init__.py +2 -0
- model/convnext.py +55 -0
- model/edsr.py +122 -0
- model/hyper.py +41 -0
- model/init.py +24 -0
- model/rdn.py +72 -0
- model/swin_ir.py +532 -0
- model/tail.py +18 -0
- model/thera.py +175 -0
- requirements.txt +36 -5
- super_resolve.py +99 -0
- utils.py +36 -0
app.py
CHANGED
@@ -1,137 +1,91 @@
(Removed: the previous pipeline, which preprocessed with torchvision transforms, estimated depth with DPT and a bicubic torch.nn.functional.interpolate, generated the relief through a ControlNet-style create_bas_relief(prompt, image, depth_map, pipe) at 1024x1024 with strength=0.8 and 30 inference steps, loaded models via load_depth_model() and load_controlnet(), and wired submit_btn.click(process, inputs=[input_image, prompt], outputs=[upscaled_output, depth_output, basrelief_output]).)

import gradio as gr
import torch
import jax
import numpy as np
from PIL import Image
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from super_resolve import process as thera_process  # assumes the Thera imports

# Configuration
DEVICE = "cpu"  # or "cuda" if available
JAX_DEVICE = jax.devices("cpu")[0]  # use CPU for JAX

# 1. Load the Thera models (EDSR/RDN)
# (implement as in the original Thera code)
model_edsr, params_edsr = None, None  # load via pickle / HF Hub

# 2. Load SDXL Img2Img + LoRA
print("Loading SDXL Img2Img with LoRA...")
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32
).to(DEVICE)
pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")

# 3. Load the depth model
print("Loading DPT...")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(DEVICE)


def enhance_depth_map(depth_arr):
    depth_normalized = (depth_arr - depth_arr.min()) / (depth_arr.max() - depth_arr.min() + 1e-8)
    return Image.fromarray((depth_normalized * 255).astype(np.uint8))


def full_pipeline(image, prompt, scale_factor=2.0):
    # 1. Super-resolution with Thera
    source = np.array(image) / 255.0
    target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
    upscaled = thera_process(source, model_edsr, params_edsr, target_shape, do_ensemble=True)
    # process() already returns a uint8 array in [0, 255]
    upscaled_pil = Image.fromarray(np.asarray(upscaled))

    # 2. Generate the bas-relief with SDXL Img2Img
    full_prompt = f"BAS-RELIEF {prompt}, intricate carving, marble relief"
    bas_relief = pipe(
        prompt=full_prompt,
        image=upscaled_pil,
        strength=0.7,
        num_inference_steps=25,
        guidance_scale=7.5
    ).images[0]

    # 3. Compute the depth map
    inputs = feature_extractor(bas_relief, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = depth_model(**inputs)
        depth = outputs.predicted_depth

    depth_map = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=bas_relief.size[::-1],
        mode="bicubic"
    ).squeeze().cpu().numpy()

    return upscaled_pil, bas_relief, enhance_depth_map(depth_map)


# Gradio interface
with gr.Blocks(title="Super Resolution + Bas-Relief") as app:
    gr.Markdown("## 📈 Super Resolution + 🗿 Bas-Relief + 🗺️ Depth Map")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Input Image")
            prompt = gr.Textbox("ancient sculpture, marble", label="Relief Description")
            scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
            btn = gr.Button("Process")

        with gr.Column():
            img_upscaled = gr.Image(label="Super-Resolved Image")
            img_basrelief = gr.Image(label="Sculptural Relief")
            img_depth = gr.Image(label="Depth Map")

    btn.click(
        full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[img_upscaled, img_basrelief, img_depth]
    )

if __name__ == "__main__":
    app.launch()  # assumed; the launch call sits in unchanged context below this hunk
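Note that nothing in this commit fills the model_edsr, params_edsr = None, None placeholder, so full_pipeline will fail until real weights are loaded. A minimal loading sketch, not part of the diff, following the pickle layout super_resolve.py expects ('model' / 'backbone' / 'size' keys); the Hub repo id and filename are hypothetical:

import pickle

from huggingface_hub import hf_hub_download

from model import build_thera

# Hypothetical checkpoint location; any pickle with the
# {'model', 'backbone', 'size'} layout used by super_resolve.py works.
ckpt_path = hf_hub_download("someuser/thera-edsr", "thera-edsr.pkl")
with open(ckpt_path, "rb") as fh:
    check = pickle.load(fh)

params_edsr = check["model"]                                   # hypernetwork parameters
model_edsr = build_thera(3, check["backbone"], check["size"])  # Thera model object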
model/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .hyper import Hypernetwork
from .thera import build_thera
model/convnext.py
ADDED
@@ -0,0 +1,55 @@
import flax.linen as nn
from jaxtyping import Array, ArrayLike


class ConvNeXtBlock(nn.Module):
    """ConvNext block. See Fig.4 in "A ConvNet for the 2020s" by Liu et al.

    https://openaccess.thecvf.com/content/CVPR2022/papers/Liu_A_ConvNet_for_the_2020s_CVPR_2022_paper.pdf
    """
    n_dims: int = 64
    kernel_size: int = 3  # 7 in the paper's version
    group_features: bool = False

    def setup(self) -> None:
        self.residual = nn.Sequential([
            nn.Conv(self.n_dims, kernel_size=(self.kernel_size, self.kernel_size), use_bias=False,
                    feature_group_count=self.n_dims if self.group_features else 1),
            nn.LayerNorm(),
            nn.Conv(4 * self.n_dims, kernel_size=(1, 1)),
            nn.gelu,
            nn.Conv(self.n_dims, kernel_size=(1, 1)),
        ])

    def __call__(self, x: ArrayLike) -> Array:
        return x + self.residual(x)


class Projection(nn.Module):
    n_dims: int

    @nn.compact
    def __call__(self, x: ArrayLike) -> Array:
        x = nn.LayerNorm()(x)
        x = nn.Conv(self.n_dims, (1, 1))(x)
        return x


class ConvNeXt(nn.Module):
    block_defs: list[tuple]

    def setup(self) -> None:
        layers = []
        current_size = self.block_defs[0][0]
        for block_def in self.block_defs:
            if block_def[0] != current_size:
                layers.append(Projection(block_def[0]))
            layers.append(ConvNeXtBlock(*block_def))
            current_size = block_def[0]
        self.layers = layers

    def __call__(self, x: ArrayLike, _: bool) -> Array:
        for layer in self.layers:
            x = layer(x)
        return x
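Usage sketch (not in the commit): the 'plus' tail in model/tail.py instantiates this module with 16 block definitions; running it on dummy backbone features shows the channel progression 64 to 96 to 128 through the interleaved Projection layers.

import jax
import jax.numpy as jnp

from model.convnext import ConvNeXt

# same block definitions as the 'plus' tail in model/tail.py
blocks = [(64, 3, True)] * 6 + [(96, 3, True)] * 7 + [(128, 3, True)] * 3
model = ConvNeXt(blocks)

x = jnp.zeros((1, 48, 48, 64))  # dummy NHWC backbone features
params = model.init(jax.random.PRNGKey(0), x, False)
y = model.apply(params, x, False)
assert y.shape == (1, 48, 48, 128)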
model/edsr.py
ADDED
@@ -0,0 +1,122 @@
# from https://github.com/isaaccorley/jax-enhance

from functools import partial
from typing import Any, Sequence, Callable

import jax.numpy as jnp
import flax.linen as nn
from flax.core.frozen_dict import freeze
import einops


class PixelShuffle(nn.Module):
    scale_factor: int

    def setup(self):
        self.layer = partial(
            einops.rearrange,
            pattern="b h w (c h2 w2) -> b (h h2) (w w2) c",
            h2=self.scale_factor,
            w2=self.scale_factor
        )

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.layer(x)


class ResidualBlock(nn.Module):
    channels: int
    kernel_size: Sequence[int]
    res_scale: float
    activation: Callable
    dtype: Any = jnp.float32

    def setup(self):
        self.body = nn.Sequential([
            nn.Conv(features=self.channels, kernel_size=self.kernel_size, dtype=self.dtype),
            self.activation,
            nn.Conv(features=self.channels, kernel_size=self.kernel_size, dtype=self.dtype),
        ])

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return x + self.body(x)


class UpsampleBlock(nn.Module):
    num_upsamples: int
    channels: int
    kernel_size: Sequence[int]
    dtype: Any = jnp.float32

    def setup(self):
        layers = []
        for _ in range(self.num_upsamples):
            layers.extend([
                nn.Conv(features=self.channels * 2 ** 2, kernel_size=self.kernel_size, dtype=self.dtype),
                PixelShuffle(scale_factor=2),
            ])
        self.layers = layers

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        for layer in self.layers:
            x = layer(x)
        return x


class EDSR(nn.Module):
    """Enhanced Deep Residual Networks for Single Image Super-Resolution https://arxiv.org/pdf/1707.02921v1.pdf"""
    scale_factor: int
    channels: int = 3
    num_blocks: int = 32
    num_feats: int = 256
    dtype: Any = jnp.float32

    def setup(self):
        # pre res blocks layer
        self.head = nn.Sequential([nn.Conv(features=self.num_feats, kernel_size=(3, 3), dtype=self.dtype)])

        # res blocks
        res_blocks = [
            ResidualBlock(channels=self.num_feats, kernel_size=(3, 3), res_scale=0.1, activation=nn.relu, dtype=self.dtype)
            for i in range(self.num_blocks)
        ]
        res_blocks.append(nn.Conv(features=self.num_feats, kernel_size=(3, 3), dtype=self.dtype))
        self.body = nn.Sequential(res_blocks)

    def __call__(self, x: jnp.ndarray, _=None) -> jnp.ndarray:
        x = self.head(x)
        x = x + self.body(x)
        return x


def convert_edsr_checkpoint(torch_dict, no_upsampling=True):
    def convert(in_dict):
        top_keys = set([k.split('.')[0] for k in in_dict.keys()])
        leaves = set([k for k in in_dict.keys() if '.' not in k])

        # convert leaves
        out_dict = {}
        for l in leaves:
            if l == 'weight':
                out_dict['kernel'] = jnp.asarray(in_dict[l]).transpose((2, 3, 1, 0))
            elif l == 'bias':
                out_dict[l] = jnp.asarray(in_dict[l])
            else:
                out_dict[l] = in_dict[l]

        for top_key in top_keys.difference(leaves):
            new_top_key = 'layers_' + top_key if top_key.isdigit() else top_key
            out_dict[new_top_key] = convert(
                {k[len(top_key) + 1:]: v for k, v in in_dict.items() if k.startswith(top_key)})
        return out_dict

    converted = convert(torch_dict)

    # remove unwanted keys
    if no_upsampling:
        del converted['tail']

    for k in ('add_mean', 'sub_mean'):
        del converted[k]

    return freeze(converted)
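convert_edsr_checkpoint remaps a PyTorch EDSR state_dict into this Flax module's tree: numeric submodule keys become layers_N, conv weights are transposed from OIHW to HWIO, and the tail/add_mean/sub_mean tensors (upsampling and mean-shift, unused here) are dropped. A hedged sketch of the conversion (the checkpoint path is hypothetical, and the state_dict is assumed to follow the official EDSR-PyTorch layout):

import torch

from model.edsr import convert_edsr_checkpoint

# hypothetical path to an EDSR-baseline PyTorch checkpoint
state_dict = torch.load("edsr_baseline_x2.pt", map_location="cpu")
numpy_state = {k: v.numpy() for k, v in state_dict.items()}

params = convert_edsr_checkpoint(numpy_state)  # FrozenDict of Flax parameters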
model/hyper.py
ADDED
@@ -0,0 +1,41 @@
import math

import jax
import jax.numpy as jnp
import flax.linen as nn
from jaxtyping import Array, ArrayLike, PyTreeDef
import numpy as np

from utils import interpolate_grid


class Hypernetwork(nn.Module):
    encoder: nn.Module
    refine: nn.Module
    output_params_shape: list[tuple]  # e.g. [(16,), (32, 32), ...]
    tree_def: PyTreeDef  # used to reconstruct the parameter sets

    def setup(self):
        # one layer 1x1 conv to calculate field params, as in SIREN paper
        output_size = sum(math.prod(s) for s in self.output_params_shape)
        self.out_conv = nn.Conv(output_size, kernel_size=(1, 1), use_bias=True)

    def get_encoding(self, source: ArrayLike, training=False) -> Array:
        """Convenience method for whole-image evaluation"""
        return self.refine(self.encoder(source, training), training)

    def get_params_at_coords(self, encoding: ArrayLike, coords: ArrayLike) -> Array:
        encoding = interpolate_grid(coords, encoding)
        phi_params = self.out_conv(encoding)

        # reshape to output params shape
        phi_params = jnp.split(
            phi_params, np.cumsum([math.prod(s) for s in self.output_params_shape[:-1]]), axis=-1)
        phi_params = [jnp.reshape(p, p.shape[:-1] + s) for p, s in
                      zip(phi_params, self.output_params_shape)]

        return jax.tree_util.tree_unflatten(self.tree_def, phi_params)

    def __call__(self, source: ArrayLike, target_coords: ArrayLike, training=False) -> Array:
        encoding = self.get_encoding(source, training)
        return self.get_params_at_coords(encoding, target_coords)
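The hypernetwork's only learned head is that 1x1 convolution, sized to the flattened field parameters. Illustrative arithmetic (not in the commit), using the field shapes set up in model/thera.py below for hidden_dim=32, out_dim=3 (the 'air' size):

import math

# TheraField parameters: a (32,) phase vector and a (32, 3) output kernel
output_params_shape = [(32,), (32, 3)]
output_size = sum(math.prod(s) for s in output_params_shape)
assert output_size == 128  # channels of the hypernetwork's 1x1 output conv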
model/init.py
ADDED
@@ -0,0 +1,24 @@
from typing import Callable

import jax
import jax.numpy as jnp
from jaxtyping import Array


def uniform_between(a: float, b: float, dtype=jnp.float32) -> Callable:
    def init(key, shape, dtype=dtype) -> Array:
        return jax.random.uniform(key, shape, dtype=dtype, minval=a, maxval=b)
    return init


def linear_up(scale: float) -> Callable:
    def init(key, shape, dtype=jnp.float32) -> Array:
        assert shape[-2] == 2
        keys = jax.random.split(key, 2)
        norm = jnp.pi * scale * (
            jax.random.uniform(keys[0], shape=(1, shape[-1])) ** .5)
        theta = 2 * jnp.pi * jax.random.uniform(keys[1], shape=(1, shape[-1]))
        x = norm * jnp.cos(theta)
        y = norm * jnp.sin(theta)
        return jnp.concatenate([x, y], axis=-2).astype(dtype)
    return init
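linear_up draws the 2xN matrix of first-layer frequencies: each column has a uniformly random orientation, and the square-root transform on the radius makes the samples uniform over a disk of radius pi * scale. A quick check (not in the commit):

import jax
import jax.numpy as jnp

from model.init import linear_up

components = linear_up(scale=1.0)(jax.random.PRNGKey(0), (2, 512))

assert components.shape == (2, 512)
assert float(jnp.linalg.norm(components, axis=0).max()) <= jnp.pi  # radius bounded by pi * scale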
model/rdn.py
ADDED
@@ -0,0 +1,72 @@
# Residual Dense Network for Image Super-Resolution
# https://arxiv.org/abs/1802.08797
# modified from: https://github.com/thstkdgus35/EDSR-PyTorch

import jax.numpy as jnp
import flax.linen as nn


class RDB_Conv(nn.Module):
    growRate: int
    kSize: int = 3

    @nn.compact
    def __call__(self, x):
        out = nn.Sequential([
            nn.Conv(self.growRate, (self.kSize, self.kSize), padding=(self.kSize-1)//2),
            nn.activation.relu
        ])(x)
        return jnp.concatenate((x, out), -1)


class RDB(nn.Module):
    growRate0: int
    growRate: int
    nConvLayers: int

    @nn.compact
    def __call__(self, x):
        res = x

        for c in range(self.nConvLayers):
            x = RDB_Conv(self.growRate)(x)

        x = nn.Conv(self.growRate0, (1, 1))(x)

        return x + res


class RDN(nn.Module):
    G0: int = 64
    RDNkSize: int = 3
    RDNconfig: str = 'B'
    scale: int = 2
    n_colors: int = 3

    @nn.compact
    def __call__(self, x, _=None):
        D, C, G = {
            'A': (20, 6, 32),
            'B': (16, 8, 64),
        }[self.RDNconfig]

        # Shallow feature extraction
        f_1 = nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))(x)
        x = nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))(f_1)

        # Residual dense blocks and dense feature fusion
        RDBs_out = []
        for i in range(D):
            x = RDB(self.G0, G, C)(x)
            RDBs_out.append(x)

        x = jnp.concatenate(RDBs_out, -1)

        # Global feature fusion
        x = nn.Sequential([
            nn.Conv(self.G0, (1, 1)),
            nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))
        ])(x)

        x = x + f_1
        return x
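Usage sketch (not in the commit): the _=None second argument lets RDN plug into the same backbone slot as EDSR; config 'B' stacks 16 residual dense blocks of 8 conv layers each and returns G0 = 64 feature channels at input resolution.

import jax
import jax.numpy as jnp

from model.rdn import RDN

model = RDN()  # defaults: config 'B', G0 = 64
x = jnp.zeros((1, 48, 48, 3))
params = model.init(jax.random.PRNGKey(0), x)
features = model.apply(params, x)
assert features.shape == (1, 48, 48, 64)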
model/swin_ir.py
ADDED
@@ -0,0 +1,532 @@
import math
from typing import Callable, Optional, Iterable

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from jaxtyping import Array


def trunc_normal(mean=0., std=1., a=-2., b=2., dtype=jnp.float32) -> Callable:
    """Truncated normal initialization function"""

    def init(key, shape, dtype=dtype) -> Array:
        # https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/weight_init.py
        def norm_cdf(x):
            # Computes standard normal cumulative distribution function
            return (1. + math.erf(x / math.sqrt(2.))) / 2.

        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        out = jax.random.uniform(key, shape, dtype=dtype, minval=2 * l - 1, maxval=2 * u - 1)
        out = jax.scipy.special.erfinv(out) * std * math.sqrt(2.) + mean
        return jnp.clip(out, a, b)

    return init


def Dense(features, use_bias=True, kernel_init=trunc_normal(std=.02), bias_init=nn.initializers.zeros):
    return nn.Dense(features, use_bias=use_bias, kernel_init=kernel_init, bias_init=bias_init)


def LayerNorm():
    """torch LayerNorm uses larger epsilon by default"""
    return nn.LayerNorm(epsilon=1e-05)


class Mlp(nn.Module):

    in_features: int
    hidden_features: int = None
    out_features: int = None
    act_layer: Callable = nn.gelu
    drop: float = 0.0

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(self.hidden_features or self.in_features)(x)
        x = self.act_layer(x)
        x = nn.Dropout(self.drop, deterministic=not training)(x)
        x = nn.Dense(self.out_features or self.in_features)(x)
        x = nn.Dropout(self.drop, deterministic=not training)(x)
        return x


def window_partition(x, window_size: int):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.reshape((B, H // window_size, window_size, W // window_size, window_size, C))
    windows = x.transpose((0, 1, 3, 2, 4, 5)).reshape((-1, window_size, window_size, C))
    return windows


def window_reverse(windows, window_size: int, H: int, W: int):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.reshape((B, H // window_size, W // window_size, window_size, window_size, -1))
    x = x.transpose((0, 1, 3, 2, 4, 5)).reshape((B, H, W, -1))
    return x


class DropPath(nn.Module):
    """
    Implementation referred from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
    """

    dropout_prob: float = 0.1
    deterministic: Optional[bool] = None

    @nn.compact
    def __call__(self, input, training):
        if not training:
            return input
        keep_prob = 1 - self.dropout_prob
        shape = (input.shape[0],) + (1,) * (input.ndim - 1)
        rng = self.make_rng("dropout")
        random_tensor = keep_prob + jax.random.uniform(rng, shape)
        random_tensor = jnp.floor(random_tensor)
        return jnp.divide(input, keep_prob) * random_tensor


class WindowAttention(nn.Module):
    dim: int
    window_size: Iterable[int]
    num_heads: int
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    att_drop: float = 0.0
    proj_drop: float = 0.0

    def make_rel_pos_index(self):
        h_indices = np.arange(0, self.window_size[0])
        w_indices = np.arange(0, self.window_size[1])
        indices = np.stack(np.meshgrid(w_indices, h_indices, indexing="ij"))
        flatten_indices = np.reshape(indices, (2, -1))
        relative_indices = flatten_indices[:, :, None] - flatten_indices[:, None, :]
        relative_indices = np.transpose(relative_indices, (1, 2, 0))
        relative_indices[:, :, 0] += self.window_size[0] - 1
        relative_indices[:, :, 1] += self.window_size[1] - 1
        relative_indices[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_pos_index = np.sum(relative_indices, -1)
        return relative_pos_index

    @nn.compact
    def __call__(self, inputs, mask, training):
        rpbt = self.param(
            "relative_position_bias_table",
            trunc_normal(std=.02),
            (
                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
                self.num_heads,
            ),
        )

        # relative_pos_index = self.variable(
        #     "variables", "relative_position_index", self.get_rel_pos_index
        # )

        batch, n, channels = inputs.shape
        qkv = nn.Dense(self.dim * 3, use_bias=self.qkv_bias, name="qkv")(inputs)
        qkv = qkv.reshape(batch, n, 3, self.num_heads, channels // self.num_heads)
        qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        scale = self.qk_scale or (self.dim // self.num_heads) ** -0.5
        q = q * scale
        att = q @ jnp.swapaxes(k, -2, -1)

        rel_pos_bias = jnp.reshape(
            rpbt[np.reshape(self.make_rel_pos_index(), (-1))],
            (
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1,
            ),
        )
        rel_pos_bias = jnp.transpose(rel_pos_bias, (2, 0, 1))
        att += jnp.expand_dims(rel_pos_bias, 0)

        if mask is not None:
            att = jnp.reshape(
                att, (batch // mask.shape[0], mask.shape[0], self.num_heads, n, n)
            )
            att = att + jnp.expand_dims(jnp.expand_dims(mask, 1), 0)
            att = jnp.reshape(att, (-1, self.num_heads, n, n))
            att = jax.nn.softmax(att)
        else:
            att = jax.nn.softmax(att)

        att = nn.Dropout(self.att_drop)(att, deterministic=not training)

        x = jnp.reshape(jnp.swapaxes(att @ v, 1, 2), (batch, n, channels))
        x = nn.Dense(self.dim, name="proj")(x)
        x = nn.Dropout(self.proj_drop)(x, deterministic=not training)
        return x


class SwinTransformerBlock(nn.Module):

    dim: int
    input_resolution: tuple[int]
    num_heads: int
    window_size: int = 7
    shift_size: int = 0
    mlp_ratio: float = 4.
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop: float = 0.
    attn_drop: float = 0.
    drop_path: float = 0.
    act_layer: Callable = nn.activation.gelu
    norm_layer: Callable = LayerNorm

    @staticmethod
    def make_att_mask(shift_size, window_size, height, width):
        if shift_size > 0:
            mask = jnp.zeros([1, height, width, 1])
            h_slices = (
                slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None),
            )
            w_slices = (
                slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None),
            )

            count = 0
            for h in h_slices:
                for w in w_slices:
                    mask = mask.at[:, h, w, :].set(count)
                    count += 1

            mask_windows = window_partition(mask, window_size)
            mask_windows = jnp.reshape(mask_windows, (-1, window_size * window_size))
            att_mask = jnp.expand_dims(mask_windows, 1) - jnp.expand_dims(mask_windows, 2)
            att_mask = jnp.where(att_mask != 0.0, float(-100.0), att_mask)
            att_mask = jnp.where(att_mask == 0.0, float(0.0), att_mask)
        else:
            att_mask = None

        return att_mask

    @nn.compact
    def __call__(self, x, x_size, training):
        H, W = x_size
        B, L, C = x.shape

        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"

        shortcut = x
        x = self.norm_layer()(x)
        x = x.reshape((B, H, W, C))

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = jnp.roll(x, (-self.shift_size, -self.shift_size), axis=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.reshape((-1, self.window_size * self.window_size, C))  # nW*B, window_size*window_size, C

        # attn_mask = self.variable(
        #     "variables",
        #     "attn_mask",
        #     self.get_att_mask,
        #     self.shift_size,
        #     self.window_size,
        #     self.input_resolution[0],
        #     self.input_resolution[1]
        # )

        attn_mask = self.make_att_mask(self.shift_size, self.window_size, *self.input_resolution)

        attn = WindowAttention(self.dim, (self.window_size, self.window_size), self.num_heads,
                               self.qkv_bias, self.qk_scale, self.attn_drop, self.drop)
        if self.input_resolution == x_size:
            attn_windows = attn(x_windows, attn_mask, training)  # nW*B, window_size*window_size, C
        else:
            # test time
            assert not training
            test_mask = self.make_att_mask(self.shift_size, self.window_size, *x_size)
            attn_windows = attn(x_windows, test_mask, training=False)

        # merge windows
        attn_windows = attn_windows.reshape((-1, self.window_size, self.window_size, C))
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = jnp.roll(shifted_x, (self.shift_size, self.shift_size), axis=(1, 2))
        else:
            x = shifted_x

        x = x.reshape((B, H * W, C))

        # FFN
        x = shortcut + DropPath(self.drop_path)(x, training)

        norm = self.norm_layer()(x)
        mlp = Mlp(in_features=self.dim, hidden_features=int(self.dim * self.mlp_ratio),
                  act_layer=self.act_layer, drop=self.drop)(norm, training)
        x = x + DropPath(self.drop_path)(mlp, training)

        return x


class PatchMerging(nn.Module):
    inp_res: Iterable[int]
    dim: int
    norm_layer: Callable = LayerNorm

    @nn.compact
    def __call__(self, inputs):
        batch, n, channels = inputs.shape
        height, width = self.inp_res[0], self.inp_res[1]
        x = jnp.reshape(inputs, (batch, height, width, channels))

        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]

        x = jnp.concatenate([x0, x1, x2, x3], axis=-1)
        x = jnp.reshape(x, (batch, -1, 4 * channels))
        x = self.norm_layer()(x)
        x = nn.Dense(2 * self.dim, use_bias=False)(x)
        return x


class BasicLayer(nn.Module):

    dim: int
    input_resolution: int
    depth: int
    num_heads: int
    window_size: int
    mlp_ratio: float = 4.
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop: float = 0.
    attn_drop: float = 0.
    drop_path: float = 0.
    norm_layer: Callable = LayerNorm
    downsample: Optional[Callable] = None

    @nn.compact
    def __call__(self, x, x_size, training):
        for i in range(self.depth):
            x = SwinTransformerBlock(
                self.dim,
                self.input_resolution,
                self.num_heads,
                self.window_size,
                0 if (i % 2 == 0) else self.window_size // 2,
                self.mlp_ratio,
                self.qkv_bias,
                self.qk_scale,
                self.drop,
                self.attn_drop,
                self.drop_path[i] if isinstance(self.drop_path, (list, tuple)) else self.drop_path,
                norm_layer=self.norm_layer
            )(x, x_size, training)

        if self.downsample is not None:
            x = self.downsample(self.input_resolution, dim=self.dim, norm_layer=self.norm_layer)(x)

        return x


class RSTB(nn.Module):

    dim: int
    input_resolution: int
    depth: int
    num_heads: int
    window_size: int
    mlp_ratio: float = 4.
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop: float = 0.
    attn_drop: float = 0.
    drop_path: float = 0.
    norm_layer: Callable = LayerNorm
    downsample: Optional[Callable] = None
    img_size: int = 224
    patch_size: int = 4
    resi_connection: str = '1conv'

    @nn.compact
    def __call__(self, x, x_size, training):
        res = x
        x = BasicLayer(dim=self.dim,
                       input_resolution=self.input_resolution,
                       depth=self.depth,
                       num_heads=self.num_heads,
                       window_size=self.window_size,
                       mlp_ratio=self.mlp_ratio,
                       qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
                       drop=self.drop, attn_drop=self.attn_drop,
                       drop_path=self.drop_path,
                       norm_layer=self.norm_layer,
                       downsample=self.downsample)(x, x_size, training)

        x = PatchUnEmbed(embed_dim=self.dim)(x, x_size)

        # resi_connection == '1conv':
        x = nn.Conv(self.dim, (3, 3))(x)

        x = PatchEmbed()(x)

        return x + res


class PatchEmbed(nn.Module):
    norm_layer: Optional[Callable] = None

    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1, x.shape[-1]))  # B Ph Pw C -> B Ph*Pw C
        if self.norm_layer is not None:
            x = self.norm_layer()(x)
        return x


class PatchUnEmbed(nn.Module):
    embed_dim: int = 96

    @nn.compact
    def __call__(self, x, x_size):
        B, HW, C = x.shape
        x = x.reshape((B, x_size[0], x_size[1], self.embed_dim))
        return x


class SwinIR(nn.Module):
    r""" SwinIR JAX implementation
    Args:
        img_size (int | tuple(int)): Input image size. Default 64
        patch_size (int | tuple(int)): Patch size. Default: 1
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
        img_range: Image range. 1. or 255.
    """

    img_size: int = 48
    patch_size: int = 1
    in_chans: int = 3
    embed_dim: int = 180
    depths: tuple = (6, 6, 6, 6, 6, 6)
    num_heads: tuple = (6, 6, 6, 6, 6, 6)
    window_size: int = 8
    mlp_ratio: float = 2.
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop_rate: float = 0.
    attn_drop_rate: float = 0.
    drop_path_rate: float = 0.1
    norm_layer: Callable = LayerNorm
    ape: bool = False
    patch_norm: bool = True
    upscale: int = 2
    img_range: float = 1.
    num_feat: int = 64

    def pad(self, x):
        _, h, w, _ = x.shape
        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
        x = jnp.pad(x, ((0, 0), (0, mod_pad_h), (0, mod_pad_w), (0, 0)), 'reflect')
        return x

    @nn.compact
    def __call__(self, x, training):
        _, h_before, w_before, _ = x.shape
        x = self.pad(x)
        _, h, w, _ = x.shape
        patches_resolution = [self.img_size // self.patch_size] * 2
        num_patches = patches_resolution[0] * patches_resolution[1]

        # conv_first
        x = nn.Conv(self.embed_dim, (3, 3))(x)
        res = x

        # feature extraction
        x_size = (h, w)
        x = PatchEmbed(self.norm_layer if self.patch_norm else None)(x)

        if self.ape:
            absolute_pos_embed = \
                self.param('ape', trunc_normal(std=.02), (1, num_patches, self.embed_dim))
            x = x + absolute_pos_embed

        x = nn.Dropout(self.drop_rate, deterministic=not training)(x)

        dpr = [x.item() for x in np.linspace(0, self.drop_path_rate, sum(self.depths))]
        for i_layer in range(len(self.depths)):
            x = RSTB(
                dim=self.embed_dim,
                input_resolution=(patches_resolution[0], patches_resolution[1]),
                depth=self.depths[i_layer],
                num_heads=self.num_heads[i_layer],
                window_size=self.window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
                drop=self.drop_rate, attn_drop=self.attn_drop_rate,
                drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
                norm_layer=self.norm_layer,
                downsample=None,
                img_size=self.img_size,
                patch_size=self.patch_size)(x, x_size, training)

        x = self.norm_layer()(x)  # B L C
        x = PatchUnEmbed(self.embed_dim)(x, x_size)

        # conv_after_body
        x = nn.Conv(self.embed_dim, (3, 3))(x)
        x = x + res

        # conv_before_upsample
        x = nn.activation.leaky_relu(nn.Conv(self.num_feat, (3, 3))(x))

        # revert padding
        x = x[:, :-(h - h_before) or None, :-(w - w_before) or None]
        return x
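window_partition and window_reverse are exact inverses whenever H and W are multiples of the window size, which the reflect padding in SwinIR.pad guarantees. A small round-trip check (not in the commit):

import jax
import jax.numpy as jnp

from model.swin_ir import window_partition, window_reverse

x = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 16, 180))  # (B, H, W, C)
windows = window_partition(x, window_size=8)  # -> (2 * 2 * 2, 8, 8, 180)
restored = window_reverse(windows, 8, 16, 16)
assert jnp.allclose(x, restored)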
model/tail.py
ADDED
@@ -0,0 +1,18 @@
import flax.linen as nn

from .convnext import ConvNeXt
from .swin_ir import SwinIR


def build_tail(size: str):
    """ Convenience function to build the three tails described in the paper. """
    if size == 'air':
        return lambda x, _: x
    elif size == 'plus':
        blocks = [(64, 3, True)] * 6 + [(96, 3, True)] * 7 + [(128, 3, True)] * 3
        return ConvNeXt(blocks)
    elif size == 'pro':
        return SwinIR(depths=[7, 6], num_heads=[6, 6])
    else:
        raise NotImplementedError('size: ' + size)
model/thera.py
ADDED
@@ -0,0 +1,175 @@
import math

import jax
from flax.core import unfreeze, freeze
import jax.numpy as jnp
import flax.linen as nn
from jaxtyping import Array, ArrayLike, PyTree

from .edsr import EDSR
from .rdn import RDN
from .hyper import Hypernetwork
from .tail import build_tail
from .init import uniform_between, linear_up
from utils import make_grid, interpolate_grid, repeat_vmap


class Thermal(nn.Module):
    w0_scale: float = 1.

    @nn.compact
    def __call__(self, x: ArrayLike, t, norm, k) -> Array:
        phase = self.param('phase', nn.initializers.uniform(.5), x.shape[-1:])
        return jnp.sin(self.w0_scale * x + phase) * jnp.exp(-(self.w0_scale * norm)**2 * k * t)


class TheraField(nn.Module):
    dim_hidden: int
    dim_out: int
    w0: float = 1.
    c: float = 6.

    @nn.compact
    def __call__(self, x: ArrayLike, t: ArrayLike, k: ArrayLike, components: ArrayLike) -> Array:
        # coordinate projection according to shared components ("first layer")
        x = x @ components

        # thermal activations
        norm = jnp.linalg.norm(components, axis=-2)
        x = Thermal(self.w0)(x, t, norm, k)

        # linear projection from hidden to output space ("second layer")
        w_std = math.sqrt(self.c / self.dim_hidden) / self.w0
        dense_init_fn = uniform_between(-w_std, w_std)
        x = nn.Dense(self.dim_out, kernel_init=dense_init_fn, use_bias=False)(x)

        return x


class Thera:

    def __init__(
        self,
        hidden_dim: int,
        out_dim: int,
        backbone: nn.Module,
        tail: nn.Module,
        k_init: float = None,
        components_init_scale: float = None
    ):
        self.hidden_dim = hidden_dim
        self.k_init = k_init
        self.components_init_scale = components_init_scale

        # single TheraField object whose `apply` method is used for all grid cells
        self.field = TheraField(hidden_dim, out_dim)

        # infer output size of the hypernetwork from a sample pass through the field;
        # key doesn't matter as field params are only used for size inference
        sample_params = self.field.init(jax.random.PRNGKey(0),
                                        jnp.zeros((2,)), 0., 0., jnp.zeros((2, hidden_dim)))
        sample_params_flat, tree_def = jax.tree_util.tree_flatten(sample_params)
        param_shapes = [p.shape for p in sample_params_flat]

        self.hypernet = Hypernetwork(backbone, tail, param_shapes, tree_def)

    def init(self, key, sample_source) -> PyTree:
        keys = jax.random.split(key, 2)
        sample_coords = jnp.zeros(sample_source.shape[:-1] + (2,))
        params = unfreeze(self.hypernet.init(keys[0], sample_source, sample_coords))

        params['params']['k'] = jnp.array(self.k_init)
        params['params']['components'] = \
            linear_up(self.components_init_scale)(keys[1], (2, self.hidden_dim))

        return freeze(params)

    def apply_encoder(self, params: PyTree, source: ArrayLike, **kwargs) -> Array:
        """
        Performs a forward pass through the hypernetwork to obtain an encoding.
        """
        return self.hypernet.apply(
            params, source, method=self.hypernet.get_encoding, **kwargs)

    def apply_decoder(
        self,
        params: PyTree,
        encoding: ArrayLike,
        coords: ArrayLike,
        t: ArrayLike,
        return_jac: bool = False
    ) -> Array | tuple[Array, Array]:
        """
        Performs a forward prediction through a grid of HxW Thera fields,
        informed by `encoding`, at spatial and temporal coordinates
        `coords` and `t`, respectively.
        args:
            params: Field parameters, shape (B, H, W, N)
            encoding: Encoding tensor, shape (B, H, W, C)
            coords: Spatial coordinates in [-0.5, 0.5], shape (B, H, W, 2)
            t: Temporal coordinates, shape (B, 1)
        """
        phi_params: PyTree = self.hypernet.apply(
            params, encoding, coords, method=self.hypernet.get_params_at_coords)

        # create local coordinate systems
        source_grid = jnp.asarray(make_grid(encoding.shape[-3:-1]))
        source_coords = jnp.tile(source_grid, (encoding.shape[0], 1, 1, 1))
        interp_coords = interpolate_grid(coords, source_coords)
        rel_coords = (coords - interp_coords)
        rel_coords = rel_coords.at[..., 0].set(rel_coords[..., 0] * encoding.shape[-3])
        rel_coords = rel_coords.at[..., 1].set(rel_coords[..., 1] * encoding.shape[-2])

        # three maps over params, coords; one over t; don't map k and components
        in_axes = [(0, 0, None, None, None), (0, 0, None, None, None), (0, 0, 0, None, None)]
        apply_field = repeat_vmap(self.field.apply, in_axes)
        out = apply_field(phi_params, rel_coords, t, params['params']['k'],
                          params['params']['components'])

        if return_jac:
            apply_jac = repeat_vmap(jax.jacrev(self.field.apply, argnums=1), in_axes)
            jac = apply_jac(phi_params, rel_coords, jnp.zeros_like(t), params['params']['k'],
                            params['params']['components'])
            return out, jac

        return out

    def apply(
        self,
        params: ArrayLike,
        source: ArrayLike,
        coords: ArrayLike,
        t: ArrayLike,
        return_jac: bool = False,
        **kwargs
    ) -> Array:
        """
        Performs a forward pass through the Thera model.
        """
        encoding = self.apply_encoder(params, source, **kwargs)
        out = self.apply_decoder(params, encoding, coords, t, return_jac=return_jac)
        return out


def build_thera(
    out_dim: int,
    backbone: str,
    size: str,
    k_init: float = None,
    components_init_scale: float = None
):
    """
    Convenience function for building the three Thera sizes described in the paper.
    """
    hidden_dim = 32 if size == 'air' else 512

    if backbone == 'edsr-baseline':
        backbone_module = EDSR(None, num_blocks=16, num_feats=64)
    elif backbone == 'rdn':
        backbone_module = RDN()
    else:
        raise NotImplementedError(backbone)

    tail_module = build_tail(size)

    return Thera(hidden_dim, out_dim, backbone_module, tail_module, k_init, components_init_scale)
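Putting the pieces together, a forward pass through a freshly initialized model (a sketch, not part of the commit; k_init and components_init_scale are placeholder values, not trained settings, and the output is the residual that super_resolve.py later de-standardizes and adds to the nearest-neighbour upsample):

import jax
import jax.numpy as jnp

from model import build_thera
from utils import make_grid

model = build_thera(out_dim=3, backbone='edsr-baseline', size='air',
                    k_init=1.0, components_init_scale=1.0)  # placeholder init values

source = jnp.zeros((1, 24, 24, 3))                  # low-res input, NHWC
params = model.init(jax.random.PRNGKey(0), source)

coords = jnp.asarray(make_grid((48, 48)))[None]     # 2x target sampling grid
t = jnp.float32((48 / 24) ** -2)[None]              # thermal time, as in super_resolve.py
out = model.apply(params, source, coords, t)        # (1, 48, 48, 3) residual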
requirements.txt
CHANGED
@@ -1,6 +1,37 @@
(The previous six-line requirements file is replaced; only diffusers carries over.)

-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

ConfigArgParse==1.7
Pillow==10.0.0
chex==0.1.7
diffusers
einops==0.6.1
flax==0.6.10
flaxmodels==0.1.3
jax==0.4.11
jaxlib==0.4.11+cuda11.cudnn86
jaxtyping==0.2.20
ml-dtypes==0.1.0
numpy==1.24.1
nvidia-cublas-cu11==11.11.3.6
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-nvcc-cu11==11.8.89
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cudnn-cu11==8.9.2.26
nvidia-cufft-cu11==10.9.0.58
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusparse-cu11==11.7.5.86
opt-einsum==3.3.0
optax==0.2.0
orbax-checkpoint==0.2.4
peft
scipy==1.10.1
timm==0.9.6
torch
torchvision
tqdm==4.65.0
transformers==4.46.3
wandb

gradio==4.44.1
gradio_imageslider==0.0.20
spaces
super_resolve.py
ADDED
@@ -0,0 +1,99 @@
#!/usr/bin/env python

from argparse import ArgumentParser, Namespace
import pickle

import jax
from jax import jit
import jax.numpy as jnp
import numpy as np
from PIL import Image

from model import build_thera
from utils import make_grid, interpolate_grid

MEAN = jnp.array([.4488, .4371, .4040])
VAR = jnp.array([.25, .25, .25])
PATCH_SIZE = 256


def process_single(source, apply_encoder, apply_decoder, params, target_shape):
    t = jnp.float32((target_shape[0] / source.shape[1])**-2)[None]
    coords_nearest = jnp.asarray(make_grid(target_shape)[None])
    source_up = interpolate_grid(coords_nearest, source[None])
    source = jax.nn.standardize(source, mean=MEAN, variance=VAR)[None]

    encoding = apply_encoder(params, source)
    coords = jnp.asarray(make_grid(source_up.shape[1:3])[None])  # global sampling coords
    out = jnp.full_like(source_up, jnp.nan, dtype=jnp.float32)

    for h_min in range(0, coords.shape[1], PATCH_SIZE):
        h_max = min(h_min + PATCH_SIZE, coords.shape[1])
        for w_min in range(0, coords.shape[2], PATCH_SIZE):
            # apply decoder with one patch of coordinates
            w_max = min(w_min + PATCH_SIZE, coords.shape[2])
            coords_patch = coords[:, h_min:h_max, w_min:w_max]
            out_patch = apply_decoder(params, encoding, coords_patch, t)
            out = out.at[:, h_min:h_max, w_min:w_max].set(out_patch)

    out = out * jnp.sqrt(VAR)[None, None, None] + MEAN[None, None, None]
    out += source_up
    return out


def process(source, model, params, target_shape, do_ensemble=True):
    apply_encoder = jit(model.apply_encoder)
    apply_decoder = jit(model.apply_decoder)

    outs = []
    for i_rot in range(4 if do_ensemble else 1):
        source_ = jnp.rot90(source, k=i_rot, axes=(-3, -2))
        target_shape_ = tuple(reversed(target_shape)) if i_rot % 2 else target_shape
        out = process_single(source_, apply_encoder, apply_decoder, params, target_shape_)
        outs.append(jnp.rot90(out, k=i_rot, axes=(-2, -3)))

    out = jnp.stack(outs).mean(0).clip(0., 1.)
    return jnp.rint(out[0] * 255).astype(jnp.uint8)


def main(args: Namespace):
    source = np.asarray(Image.open(args.in_file)) / 255.

    if args.scale is not None:
        if args.size is not None:
            raise ValueError('Cannot specify both size and scale')
        target_shape = (
            round(source.shape[0] * args.scale),
            round(source.shape[1] * args.scale),
        )
    elif args.size is not None:
        target_shape = args.size
    else:
        raise ValueError('Must specify either size or scale')

    with open(args.checkpoint, 'rb') as fh:
        check = pickle.load(fh)
        params, backbone, size = check['model'], check['backbone'], check['size']

    model = build_thera(3, backbone, size)

    out = process(source, model, params, target_shape, not args.no_ensemble)

    Image.fromarray(np.asarray(out)).save(args.out_file)


def parse_args() -> Namespace:
    parser = ArgumentParser()
    parser.add_argument('in_file')
    parser.add_argument('out_file')
    parser.add_argument('--scale', type=float, help='Scale factor for super-resolution')
    parser.add_argument('--size', type=int, nargs=2,
                        help='Target size (h, w), mutually exclusive with --scale')
    parser.add_argument('--checkpoint', help='Path to checkpoint file')
    parser.add_argument('--no-ensemble', action='store_true', help='Disable geo-ensemble')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    main(args)
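Command-line usage, following the argparse definition above (file and checkpoint names are illustrative):

    python super_resolve.py input.png output.png --scale 2 --checkpoint thera-edsr.pkl
    python super_resolve.py input.png output.png --size 1080 1920 --checkpoint thera-edsr.pkl --no-ensemble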
utils.py
ADDED
@@ -0,0 +1,36 @@
from functools import partial

import jax
import numpy as np


def repeat_vmap(fun, in_axes=[0]):
    for axes in in_axes:
        fun = jax.vmap(fun, in_axes=axes)
    return fun


def make_grid(patch_size: int | tuple[int, int]):
    if isinstance(patch_size, int):
        patch_size = (patch_size, patch_size)
    offset_h, offset_w = 1 / (2 * np.array(patch_size))
    space_h = np.linspace(-0.5 + offset_h, 0.5 - offset_h, patch_size[0])
    space_w = np.linspace(-0.5 + offset_w, 0.5 - offset_w, patch_size[1])
    return np.stack(np.meshgrid(space_h, space_w, indexing='ij'), axis=-1)  # [h, w]


def interpolate_grid(coords, grid, order=0):
    """
    args:
        coords: Tensor of shape (B, H, W, 2) with coordinates in [-0.5, 0.5]
        grid: Tensor of shape (B, H', W', C)
    returns:
        Tensor of shape (B, H, W, C) with interpolated values
    """
    # convert [-0.5, 0.5] -> [0, size], where pixel centers are expected at
    # [-0.5 + 1 / (2*size), ..., 0.5 - 1 / (2*size)]
    coords = coords.transpose((0, 3, 1, 2))
    coords = coords.at[:, 0].set(coords[:, 0] * grid.shape[-3] + (grid.shape[-3] - 1) / 2)
    coords = coords.at[:, 1].set(coords[:, 1] * grid.shape[-2] + (grid.shape[-2] - 1) / 2)
    map_coordinates = partial(jax.scipy.ndimage.map_coordinates, order=order, mode='nearest')
    return jax.vmap(jax.vmap(map_coordinates, in_axes=(2, None), out_axes=2))(grid, coords)
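Sanity check (not in the commit): make_grid places one coordinate at each pixel center of the [-0.5, 0.5] square, and interpolate_grid with order=0 does nearest-neighbour lookup, so sampling a grid at its own make_grid coordinates returns it unchanged.

import jax.numpy as jnp

from utils import make_grid, interpolate_grid

grid = jnp.arange(2 * 4 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 4, 3)  # (B, H, W, C)
coords = jnp.tile(jnp.asarray(make_grid((4, 4)))[None], (2, 1, 1, 1))    # (B, H, W, 2)

out = interpolate_grid(coords, grid)
assert jnp.allclose(out, grid)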