ds1david committed on
Commit
1eb87a5
·
1 Parent(s): cda6ad2
Files changed (13)
  1. app.py +72 -118
  2. model/__init__.py +2 -0
  3. model/convnext.py +55 -0
  4. model/edsr.py +122 -0
  5. model/hyper.py +41 -0
  6. model/init.py +24 -0
  7. model/rdn.py +72 -0
  8. model/swin_ir.py +532 -0
  9. model/tail.py +18 -0
  10. model/thera.py +175 -0
  11. requirements.txt +36 -5
  12. super_resolve.py +99 -0
  13. utils.py +36 -0
app.py CHANGED
@@ -1,137 +1,91 @@
1
  import gradio as gr
2
  import torch
 
3
  import numpy as np
4
  from PIL import Image
5
- from peft import PeftModel
6
  from transformers import DPTFeatureExtractor, DPTForDepthEstimation
7
- from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
8
- from torchvision import transforms
9
-
10
- # Initial settings
11
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
12
- TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
13
-
14
-
15
- # --- Model Loading ---
16
-
17
- # 1. Thera: Super Resolution
18
- def load_thera_model():
19
- # Hypothetical model - adjust to match the actual Thera implementation
20
- model = torch.hub.load('prs-eth/thera', 'thera', trust_repo=True)
21
- return model.to(DEVICE)
22
-
23
-
24
- # 2. Depth Map with PEFT
25
- def load_depth_model():
26
- base_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
27
- model = PeftModel.from_pretrained(base_model, "danube2024/dpt-peft-lora")
28
- return model.to(DEVICE).eval()
29
-
30
-
31
- # 3. Bas-Relief with ControlNet
32
- def load_controlnet():
33
- controlnet = ControlNetModel.from_pretrained(
34
- "danube2024/controlnet-bas-relief",
35
- torch_dtype=TORCH_DTYPE
36
- )
37
- pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
38
- "stabilityai/stable-diffusion-xl-base-1.0",
39
- controlnet=controlnet,
40
- torch_dtype=TORCH_DTYPE
41
- )
42
- pipe.load_lora_weights("danube2024/bas-relief-lora")
43
- return pipe.to(DEVICE)
44
-
45
-
46
- # --- Processing ---
47
-
48
- def run_thera(image, model):
49
- transform = transforms.Compose([
50
- transforms.ToTensor(),
51
- transforms.Normalize([0.5], [0.5])
52
- ])
53
-
54
- input_tensor = transform(image).unsqueeze(0).to(DEVICE)
55
- with torch.no_grad():
56
- output = model(input_tensor)
57
-
58
- output_img = transforms.ToPILImage()(output.squeeze().cpu().clamp(-1, 1) * 0.5 + 0.5)
59
- return output_img
60
-
61
 
62
- def create_depth_map(image, model, feature_extractor):
63
- inputs = feature_extractor(images=image, return_tensors="pt").to(DEVICE)
64
  with torch.no_grad():
65
- outputs = model(**inputs)
66
- predicted_depth = outputs.predicted_depth
67
-
68
- prediction = torch.nn.functional.interpolate(
69
- predicted_depth.unsqueeze(1),
70
- size=image.size[::-1],
71
- mode="bicubic",
72
- align_corners=False,
73
- )
74
- return prediction.squeeze().cpu().numpy()
75
-
76
-
77
- def create_bas_relief(prompt, image, depth_map, pipe):
78
- control_image = Image.fromarray((depth_map * 255).astype(np.uint8))
79
-
80
- image = image.resize((1024, 1024))
81
- control_image = control_image.resize((1024, 1024))
82
-
83
- result = pipe(
84
- prompt=prompt,
85
- image=image,
86
- control_image=control_image,
87
- strength=0.8,
88
- num_inference_steps=30
89
- ).images[0]
90
 
91
- return result
 
 
 
 
92
 
 
93
 
94
- # --- Gradio Interface ---
95
 
96
- with gr.Blocks() as app:
97
- gr.Markdown("# 🖼️ Super Resolução + Depth Map + Bas-Relief")
 
98
 
99
  with gr.Row():
100
  with gr.Column():
101
- input_image = gr.Image(type="pil", label="Imagem de Entrada")
102
- prompt = gr.Textbox("high quality bas-relief sculpture, intricate details")
103
- submit_btn = gr.Button("Processar")
 
104
 
105
  with gr.Column():
106
- upscaled_output = gr.Image(label="Imagem Super Resolvida")
107
- depth_output = gr.Image(label="Mapa de Profundidade")
108
- basrelief_output = gr.Image(label="Resultado Bas-Relief")
109
-
110
-
111
- def process(image, prompt):
112
- # Load models
113
- thera_model = load_thera_model()
114
- depth_model = load_depth_model()
115
- feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
116
- basrelief_pipe = load_controlnet()
117
-
118
- # 1. Super Resolution
119
- upscaled = run_thera(image, thera_model)
120
-
121
- # 2. Depth Map
122
- depth = create_depth_map(upscaled, depth_model, feature_extractor)
123
- depth_normalized = (depth - depth.min()) / (depth.max() - depth.min())
124
-
125
- # 3. Bas-Relief
126
- basrelief = create_bas_relief(prompt, upscaled, depth_normalized, basrelief_pipe)
127
-
128
- return upscaled, depth_normalized, basrelief
129
-
130
-
131
- submit_btn.click(
132
- process,
133
- inputs=[input_image, prompt],
134
- outputs=[upscaled_output, depth_output, basrelief_output]
135
  )
136
 
137
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import torch
3
+ import jax
4
  import numpy as np
5
  from PIL import Image
6
+ from diffusers import StableDiffusionXLImg2ImgPipeline
7
  from transformers import DPTFeatureExtractor, DPTForDepthEstimation
8
+ from super_resolve import process as thera_process  # assumes the Thera imports are available
9
+
10
+ # Settings
11
+ DEVICE = "cpu" # ou "cuda" se disponível
12
+ JAX_DEVICE = jax.devices("cpu")[0] # use CPU for JAX
13
+
14
+ # 1. Load the Thera models (EDSR/RDN)
15
+ # (Implement following the original Thera code)
16
+ model_edsr, params_edsr = None, None # load via pickle / HF Hub
17
+
18
+ # 2. Load SDXL Img2Img + LoRA
19
+ print("Carregando SDXL Img2Img com LoRA...")
20
+ pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
21
+ "stabilityai/stable-diffusion-xl-base-1.0",
22
+ torch_dtype=torch.float32
23
+ ).to(DEVICE)
24
+ pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
25
+
26
+ # 3. Load the depth model
27
+ print("Carregando DPT...")
28
+ feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
29
+ depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(DEVICE)
30
+
31
+
32
+ def enhance_depth_map(depth_arr):
33
+ depth_normalized = (depth_arr - depth_arr.min()) / (depth_arr.max() - depth_arr.min() + 1e-8)
34
+ return Image.fromarray((depth_normalized * 255).astype(np.uint8))
35
+
36
+
37
+ def full_pipeline(image, prompt, scale_factor=2.0):
38
+ # 1. Super resolution with Thera
39
+ source = np.array(image) / 255.0
40
+ target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
41
+ upscaled = thera_process(source, model_edsr, params_edsr, target_shape, do_ensemble=True)
42
+ upscaled_pil = Image.fromarray((upscaled * 255).astype(np.uint8))
43
+
44
+ # 2. Generate bas-relief with SDXL Img2Img
45
+ full_prompt = f"BAS-RELIEF {prompt}, intricate carving, marble relief"
46
+ bas_relief = pipe(
47
+ prompt=full_prompt,
48
+ image=upscaled_pil,
49
+ strength=0.7,
50
+ num_inference_steps=25,
51
+ guidance_scale=7.5
52
+ ).images[0]
53
 
54
+ # 3. Compute the depth map
55
+ inputs = feature_extractor(bas_relief, return_tensors="pt").to(DEVICE)
56
  with torch.no_grad():
57
+ outputs = depth_model(**inputs)
58
+ depth = outputs.predicted_depth
59
 
60
+ depth_map = torch.nn.functional.interpolate(
61
+ depth.unsqueeze(1),
62
+ size=bas_relief.size[::-1],
63
+ mode="bicubic"
64
+ ).squeeze().cpu().numpy()
65
 
66
+ return upscaled_pil, bas_relief, enhance_depth_map(depth_map)
67
 
 
68
 
69
+ # Gradio interface
70
+ with gr.Blocks(title="Super Resolução + Bas-Relief") as app:
71
+ gr.Markdown("## 📈 Super Resolução + 🗿 Bas-Relief + 🗺️ Mapa de Profundidade")
72
 
73
  with gr.Row():
74
  with gr.Column():
75
+ img_input = gr.Image(type="pil", label="Imagem de Entrada")
76
+ prompt = gr.Textbox("ancient sculpture, marble", label="Descrição do Relevo")
77
+ scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
78
+ btn = gr.Button("Processar")
79
 
80
  with gr.Column():
81
+ img_upscaled = gr.Image(label="Imagem Super Resolvida")
82
+ img_basrelief = gr.Image(label="Relevo Escultural")
83
+ img_depth = gr.Image(label="Mapa de Profundidade")
84
+
85
+ btn.click(
86
+ full_pipeline,
87
+ inputs=[img_input, prompt, scale],
88
+ outputs=[img_upscaled, img_basrelief, img_depth]
89
  )
90
 
91
  if __name__ == "__main__":
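Note: the placeholder model_edsr, params_edsr = None, None above still needs a real Thera checkpoint. A minimal loading sketch, assuming the pickle layout used by super_resolve.py (keys 'model', 'backbone', 'size'); the Hub repo id and filename are hypothetical:

import pickle
from huggingface_hub import hf_hub_download
from model import build_thera

# Hypothetical repo/filename; any pickle with 'model', 'backbone' and 'size' entries works.
checkpoint_path = hf_hub_download("prs-eth/thera-edsr", "model.pkl")
with open(checkpoint_path, "rb") as fh:
    check = pickle.load(fh)
params_edsr = check["model"]
model_edsr = build_thera(3, check["backbone"], check["size"])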
model/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .hyper import Hypernetwork
2
+ from .thera import build_thera
model/convnext.py ADDED
@@ -0,0 +1,55 @@
1
+ import flax.linen as nn
2
+ from jaxtyping import Array, ArrayLike
3
+
4
+
5
+ class ConvNeXtBlock(nn.Module):
6
+ """ConvNext block. See Fig.4 in "A ConvNet for the 2020s" by Liu et al.
7
+
8
+ https://openaccess.thecvf.com/content/CVPR2022/papers/Liu_A_ConvNet_for_the_2020s_CVPR_2022_paper.pdf
9
+ """
10
+ n_dims: int = 64
11
+ kernel_size: int = 3 # 7 in the paper's version
12
+ group_features: bool = False
13
+
14
+ def setup(self) -> None:
15
+ self.residual = nn.Sequential([
16
+ nn.Conv(self.n_dims, kernel_size=(self.kernel_size, self.kernel_size), use_bias=False,
17
+ feature_group_count=self.n_dims if self.group_features else 1),
18
+ nn.LayerNorm(),
19
+ nn.Conv(4 * self.n_dims, kernel_size=(1, 1)),
20
+ nn.gelu,
21
+ nn.Conv(self.n_dims, kernel_size=(1, 1)),
22
+ ])
23
+
24
+ def __call__(self, x: ArrayLike) -> Array:
25
+ return x + self.residual(x)
26
+
27
+
28
+ class Projection(nn.Module):
29
+ n_dims: int
30
+
31
+ @nn.compact
32
+ def __call__(self, x: ArrayLike) -> Array:
33
+ x = nn.LayerNorm()(x)
34
+ x = nn.Conv(self.n_dims, (1, 1))(x)
35
+ return x
36
+
37
+
38
+ class ConvNeXt(nn.Module):
39
+ block_defs: list[tuple]
40
+
41
+ def setup(self) -> None:
42
+ layers = []
43
+ current_size = self.block_defs[0][0]
44
+ for block_def in self.block_defs:
45
+ if block_def[0] != current_size:
46
+ layers.append(Projection(block_def[0]))
47
+ layers.append(ConvNeXtBlock(*block_def))
48
+ current_size = block_def[0]
49
+ self.layers = layers
50
+
51
+ def __call__(self, x: ArrayLike, _: bool) -> Array:
52
+ for layer in self.layers:
53
+ x = layer(x)
54
+ return x
55
+
model/edsr.py ADDED
@@ -0,0 +1,122 @@
1
+ # from https://github.com/isaaccorley/jax-enhance
2
+
3
+ from functools import partial
4
+ from typing import Any, Sequence, Callable
5
+
6
+ import jax.numpy as jnp
7
+ import flax.linen as nn
8
+ from flax.core.frozen_dict import freeze
9
+ import einops
10
+
11
+
12
+ class PixelShuffle(nn.Module):
13
+ scale_factor: int
14
+
15
+ def setup(self):
16
+ self.layer = partial(
17
+ einops.rearrange,
18
+ pattern="b h w (c h2 w2) -> b (h h2) (w w2) c",
19
+ h2=self.scale_factor,
20
+ w2=self.scale_factor
21
+ )
22
+
23
+ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
24
+ return self.layer(x)
25
+
26
+
27
+ class ResidualBlock(nn.Module):
28
+ channels: int
29
+ kernel_size: Sequence[int]
30
+ res_scale: float
31
+ activation: Callable
32
+ dtype: Any = jnp.float32
33
+
34
+ def setup(self):
35
+ self.body = nn.Sequential([
36
+ nn.Conv(features=self.channels, kernel_size=self.kernel_size, dtype=self.dtype),
37
+ self.activation,
38
+ nn.Conv(features=self.channels, kernel_size=self.kernel_size, dtype=self.dtype),
39
+ ])
40
+
41
+ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
42
+ return x + self.body(x)
43
+
44
+
45
+ class UpsampleBlock(nn.Module):
46
+ num_upsamples: int
47
+ channels: int
48
+ kernel_size: Sequence[int]
49
+ dtype: Any = jnp.float32
50
+
51
+ def setup(self):
52
+ layers = []
53
+ for _ in range(self.num_upsamples):
54
+ layers.extend([
55
+ nn.Conv(features=self.channels * 2 ** 2, kernel_size=self.kernel_size, dtype=self.dtype),
56
+ PixelShuffle(scale_factor=2),
57
+ ])
58
+ self.layers = layers
59
+
60
+ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
61
+ for layer in self.layers:
62
+ x = layer(x)
63
+ return x
64
+
65
+
66
+ class EDSR(nn.Module):
67
+ """Enhanced Deep Residual Networks for Single Image Super-Resolution https://arxiv.org/pdf/1707.02921v1.pdf"""
68
+ scale_factor: int
69
+ channels: int = 3
70
+ num_blocks: int = 32
71
+ num_feats: int = 256
72
+ dtype: Any = jnp.float32
73
+
74
+ def setup(self):
75
+ # pre res blocks layer
76
+ self.head = nn.Sequential([nn.Conv(features=self.num_feats, kernel_size=(3, 3), dtype=self.dtype)])
77
+
78
+ # res blocks
79
+ res_blocks = [
80
+ ResidualBlock(channels=self.num_feats, kernel_size=(3, 3), res_scale=0.1, activation=nn.relu, dtype=self.dtype)
81
+ for i in range(self.num_blocks)
82
+ ]
83
+ res_blocks.append(nn.Conv(features=self.num_feats, kernel_size=(3, 3), dtype=self.dtype))
84
+ self.body = nn.Sequential(res_blocks)
85
+
86
+ def __call__(self, x: jnp.ndarray, _=None) -> jnp.ndarray:
87
+ x = self.head(x)
88
+ x = x + self.body(x)
89
+ return x
90
+
91
+
92
+ def convert_edsr_checkpoint(torch_dict, no_upsampling=True):
93
+ def convert(in_dict):
94
+ top_keys = set([k.split('.')[0] for k in in_dict.keys()])
95
+ leaves = set([k for k in in_dict.keys() if '.' not in k])
96
+
97
+ # convert leaves
98
+ out_dict = {}
99
+ for l in leaves:
100
+ if l == 'weight':
101
+ out_dict['kernel'] = jnp.asarray(in_dict[l]).transpose((2, 3, 1, 0))
102
+ elif l == 'bias':
103
+ out_dict[l] = jnp.asarray(in_dict[l])
104
+ else:
105
+ out_dict[l] = in_dict[l]
106
+
107
+ for top_key in top_keys.difference(leaves):
108
+ new_top_key = 'layers_' + top_key if top_key.isdigit() else top_key
109
+ out_dict[new_top_key] = convert(
110
+ {k[len(top_key) + 1:]: v for k, v in in_dict.items() if k.startswith(top_key)})
111
+ return out_dict
112
+
113
+ converted = convert(torch_dict)
114
+
115
+ # remove unwanted keys
116
+ if no_upsampling:
117
+ del converted['tail']
118
+
119
+ for k in ('add_mean', 'sub_mean'):
120
+ del converted[k]
121
+
122
+ return freeze(converted)
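convert_edsr_checkpoint maps a PyTorch EDSR state dict onto this module's Flax parameter tree (renaming weight/bias to kernel/bias and numeric submodules to layers_N). A rough usage sketch, assuming an EDSR-PyTorch style checkpoint; the filename is hypothetical:

import torch
from model.edsr import EDSR, convert_edsr_checkpoint

state_dict = torch.load("edsr_baseline_x2.pt", map_location="cpu")  # hypothetical file
state_dict = {k: v.numpy() for k, v in state_dict.items()}          # torch tensors -> numpy
params = {"params": convert_edsr_checkpoint(state_dict, no_upsampling=True)}
model = EDSR(None, num_blocks=16, num_feats=64)  # 'edsr-baseline' config used in thera.py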
model/hyper.py ADDED
@@ -0,0 +1,41 @@
1
+ import math
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import flax.linen as nn
6
+ from jaxtyping import Array, ArrayLike, PyTreeDef
7
+ import numpy as np
8
+
9
+ from utils import interpolate_grid
10
+
11
+
12
+ class Hypernetwork(nn.Module):
13
+ encoder: nn.Module
14
+ refine: nn.Module
15
+ output_params_shape: list[tuple] # e.g. [(16,), (32, 32), ...]
16
+ tree_def: PyTreeDef # used to reconstruct the parameter sets
17
+
18
+ def setup(self):
19
+ # one layer 1x1 conv to calculate field params, as in SIREN paper
20
+ output_size = sum(math.prod(s) for s in self.output_params_shape)
21
+ self.out_conv = nn.Conv(output_size, kernel_size=(1, 1), use_bias=True)
22
+
23
+ def get_encoding(self, source: ArrayLike, training=False) -> Array:
24
+ """Convenience method for whole-image evaluation"""
25
+ return self.refine(self.encoder(source, training), training)
26
+
27
+ def get_params_at_coords(self, encoding: ArrayLike, coords: ArrayLike) -> Array:
28
+ encoding = interpolate_grid(coords, encoding)
29
+ phi_params = self.out_conv(encoding)
30
+
31
+ # reshape to output params shape
32
+ phi_params = jnp.split(
33
+ phi_params, np.cumsum([math.prod(s) for s in self.output_params_shape[:-1]]), axis=-1)
34
+ phi_params = [jnp.reshape(p, p.shape[:-1] + s) for p, s in
35
+ zip(phi_params, self.output_params_shape)]
36
+
37
+ return jax.tree_util.tree_unflatten(self.tree_def, phi_params)
38
+
39
+ def __call__(self, source: ArrayLike, target_coords: ArrayLike, training=False) -> Array:
40
+ encoding = self.get_encoding(source, training)
41
+ return self.get_params_at_coords(encoding, target_coords)
model/init.py ADDED
@@ -0,0 +1,24 @@
1
+ from typing import Callable
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from jaxtyping import Array
6
+
7
+
8
+ def uniform_between(a: float, b: float, dtype=jnp.float32) -> Callable:
9
+ def init(key, shape, dtype=dtype) -> Array:
10
+ return jax.random.uniform(key, shape, dtype=dtype, minval=a, maxval=b)
11
+ return init
12
+
13
+
14
+ def linear_up(scale: float) -> Callable:
15
+ def init(key, shape, dtype=jnp.float32) -> Array:
16
+ assert shape[-2] == 2
17
+ keys = jax.random.split(key, 2)
18
+ norm = jnp.pi * scale * (
19
+ jax.random.uniform(keys[0], shape=(1, shape[-1])) ** .5)
20
+ theta = 2 * jnp.pi * jax.random.uniform(keys[1], shape=(1, shape[-1]))
21
+ x = norm * jnp.cos(theta)
22
+ y = norm * jnp.sin(theta)
23
+ return jnp.concatenate([x, y], axis=-2).astype(dtype)
24
+ return init
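linear_up is the initializer Thera uses for its shared frequency components (shape (2, hidden_dim)): it draws 2-D directions with norms in [0, pi * scale]. A tiny sketch:

import jax
from model.init import linear_up

components = linear_up(scale=1.0)(jax.random.PRNGKey(0), (2, 32))
print(components.shape)  # (2, 32): rows are the x and y projections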
model/rdn.py ADDED
@@ -0,0 +1,72 @@
1
+ # Residual Dense Network for Image Super-Resolution
2
+ # https://arxiv.org/abs/1802.08797
3
+ # modified from: https://github.com/thstkdgus35/EDSR-PyTorch
4
+
5
+ import jax.numpy as jnp
6
+ import flax.linen as nn
7
+
8
+
9
+ class RDB_Conv(nn.Module):
10
+ growRate: int
11
+ kSize: int = 3
12
+
13
+ @nn.compact
14
+ def __call__(self, x):
15
+ out = nn.Sequential([
16
+ nn.Conv(self.growRate, (self.kSize, self.kSize), padding=(self.kSize-1)//2),
17
+ nn.activation.relu
18
+ ])(x)
19
+ return jnp.concatenate((x, out), -1)
20
+
21
+
22
+ class RDB(nn.Module):
23
+ growRate0: int
24
+ growRate: int
25
+ nConvLayers: int
26
+
27
+ @nn.compact
28
+ def __call__(self, x):
29
+ res = x
30
+
31
+ for c in range(self.nConvLayers):
32
+ x = RDB_Conv(self.growRate)(x)
33
+
34
+ x = nn.Conv(self.growRate0, (1, 1))(x)
35
+
36
+ return x + res
37
+
38
+
39
+ class RDN(nn.Module):
40
+ G0: int = 64
41
+ RDNkSize: int = 3
42
+ RDNconfig: str = 'B'
43
+ scale: int = 2
44
+ n_colors: int = 3
45
+
46
+ @nn.compact
47
+ def __call__(self, x, _=None):
48
+ D, C, G = {
49
+ 'A': (20, 6, 32),
50
+ 'B': (16, 8, 64),
51
+ }[self.RDNconfig]
52
+
53
+ # Shallow feature extraction
54
+ f_1 = nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))(x)
55
+ x = nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))(f_1)
56
+
57
+ # Residual dense blocks and dense feature fusion
58
+ RDBs_out = []
59
+ for i in range(D):
60
+ x = RDB(self.G0, G, C)(x)
61
+ RDBs_out.append(x)
62
+
63
+ x = jnp.concatenate(RDBs_out, -1)
64
+
65
+ # Global Feature Fusion
66
+ x = nn.Sequential([
67
+ nn.Conv(self.G0, (1, 1)),
68
+ nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))
69
+ ])(x)
70
+
71
+ x = x + f_1
72
+ return x
model/swin_ir.py ADDED
@@ -0,0 +1,532 @@
1
+ import math
2
+ from typing import Callable, Optional, Iterable
3
+
4
+ import numpy as np
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import flax.linen as nn
8
+ from jaxtyping import Array
9
+
10
+
11
+ def trunc_normal(mean=0., std=1., a=-2., b=2., dtype=jnp.float32) -> Callable:
12
+ """Truncated normal initialization function"""
13
+
14
+ def init(key, shape, dtype=dtype) -> Array:
15
+ # https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/weight_init.py
16
+ def norm_cdf(x):
17
+ # Computes standard normal cumulative distribution function
18
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
19
+
20
+ l = norm_cdf((a - mean) / std)
21
+ u = norm_cdf((b - mean) / std)
22
+ out = jax.random.uniform(key, shape, dtype=dtype, minval=2 * l - 1, maxval=2 * u - 1)
23
+ out = jax.scipy.special.erfinv(out) * std * math.sqrt(2.) + mean
24
+ return jnp.clip(out, a, b)
25
+
26
+ return init
27
+
28
+
29
+ def Dense(features, use_bias=True, kernel_init=trunc_normal(std=.02), bias_init=nn.initializers.zeros):
30
+ return nn.Dense(features, use_bias=use_bias, kernel_init=kernel_init, bias_init=bias_init)
31
+
32
+
33
+ def LayerNorm():
34
+ """torch LayerNorm uses larger epsilon by default"""
35
+ return nn.LayerNorm(epsilon=1e-05)
36
+
37
+
38
+ class Mlp(nn.Module):
39
+
40
+ in_features: int
41
+ hidden_features: int = None
42
+ out_features: int = None
43
+ act_layer: Callable = nn.gelu
44
+ drop: float = 0.0
45
+
46
+ @nn.compact
47
+ def __call__(self, x, training: bool):
48
+ x = nn.Dense(self.hidden_features or self.in_features)(x)
49
+ x = self.act_layer(x)
50
+ x = nn.Dropout(self.drop, deterministic=not training)(x)
51
+ x = nn.Dense(self.out_features or self.in_features)(x)
52
+ x = nn.Dropout(self.drop, deterministic=not training)(x)
53
+ return x
54
+
55
+
56
+ def window_partition(x, window_size: int):
57
+ """
58
+ Args:
59
+ x: (B, H, W, C)
60
+ window_size (int): window size
61
+
62
+ Returns:
63
+ windows: (num_windows*B, window_size, window_size, C)
64
+ """
65
+ B, H, W, C = x.shape
66
+ x = x.reshape((B, H // window_size, window_size, W // window_size, window_size, C))
67
+ windows = x.transpose((0, 1, 3, 2, 4, 5)).reshape((-1, window_size, window_size, C))
68
+ return windows
69
+
70
+
71
+ def window_reverse(windows, window_size: int, H: int, W: int):
72
+ """
73
+ Args:
74
+ windows: (num_windows*B, window_size, window_size, C)
75
+ window_size (int): Window size
76
+ H (int): Height of image
77
+ W (int): Width of image
78
+
79
+ Returns:
80
+ x: (B, H, W, C)
81
+ """
82
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
83
+ x = windows.reshape((B, H // window_size, W // window_size, window_size, window_size, -1))
84
+ x = x.transpose((0, 1, 3, 2, 4, 5)).reshape((B, H, W, -1))
85
+ return x
86
+
87
+
88
+ class DropPath(nn.Module):
89
+ """
90
+ Implementation referred from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
91
+ """
92
+
93
+ dropout_prob: float = 0.1
94
+ deterministic: Optional[bool] = None
95
+
96
+ @nn.compact
97
+ def __call__(self, input, training):
98
+ if not training:
99
+ return input
100
+ keep_prob = 1 - self.dropout_prob
101
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1)
102
+ rng = self.make_rng("dropout")
103
+ random_tensor = keep_prob + jax.random.uniform(rng, shape)
104
+ random_tensor = jnp.floor(random_tensor)
105
+ return jnp.divide(input, keep_prob) * random_tensor
106
+
107
+
108
+ class WindowAttention(nn.Module):
109
+ dim: int
110
+ window_size: Iterable[int]
111
+ num_heads: int
112
+ qkv_bias: bool = True
113
+ qk_scale: Optional[float] = None
114
+ att_drop: float = 0.0
115
+ proj_drop: float = 0.0
116
+
117
+ def make_rel_pos_index(self):
118
+ h_indices = np.arange(0, self.window_size[0])
119
+ w_indices = np.arange(0, self.window_size[1])
120
+ indices = np.stack(np.meshgrid(w_indices, h_indices, indexing="ij"))
121
+ flatten_indices = np.reshape(indices, (2, -1))
122
+ relative_indices = flatten_indices[:, :, None] - flatten_indices[:, None, :]
123
+ relative_indices = np.transpose(relative_indices, (1, 2, 0))
124
+ relative_indices[:, :, 0] += self.window_size[0] - 1
125
+ relative_indices[:, :, 1] += self.window_size[1] - 1
126
+ relative_indices[:, :, 0] *= 2 * self.window_size[1] - 1
127
+ relative_pos_index = np.sum(relative_indices, -1)
128
+ return relative_pos_index
129
+
130
+ @nn.compact
131
+ def __call__(self, inputs, mask, training):
132
+ rpbt = self.param(
133
+ "relative_position_bias_table",
134
+ trunc_normal(std=.02),
135
+ (
136
+ (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
137
+ self.num_heads,
138
+ ),
139
+ )
140
+
141
+ #relative_pos_index = self.variable(
142
+ # "variables", "relative_position_index", self.get_rel_pos_index
143
+ #)
144
+
145
+ batch, n, channels = inputs.shape
146
+ qkv = nn.Dense(self.dim * 3, use_bias=self.qkv_bias, name="qkv")(inputs)
147
+ qkv = qkv.reshape(batch, n, 3, self.num_heads, channels // self.num_heads)
148
+ qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4))
149
+ q, k, v = qkv[0], qkv[1], qkv[2]
150
+
151
+ scale = self.qk_scale or (self.dim // self.num_heads) ** -0.5
152
+ q = q * scale
153
+ att = q @ jnp.swapaxes(k, -2, -1)
154
+
155
+ rel_pos_bias = jnp.reshape(
156
+ rpbt[np.reshape(self.make_rel_pos_index(), (-1))],
157
+ (
158
+ self.window_size[0] * self.window_size[1],
159
+ self.window_size[0] * self.window_size[1],
160
+ -1,
161
+ ),
162
+ )
163
+ rel_pos_bias = jnp.transpose(rel_pos_bias, (2, 0, 1))
164
+ att += jnp.expand_dims(rel_pos_bias, 0)
165
+
166
+ if mask is not None:
167
+ att = jnp.reshape(
168
+ att, (batch // mask.shape[0], mask.shape[0], self.num_heads, n, n)
169
+ )
170
+ att = att + jnp.expand_dims(jnp.expand_dims(mask, 1), 0)
171
+ att = jnp.reshape(att, (-1, self.num_heads, n, n))
172
+ att = jax.nn.softmax(att)
173
+
174
+ else:
175
+ att = jax.nn.softmax(att)
176
+
177
+ att = nn.Dropout(self.att_drop)(att, deterministic=not training)
178
+
179
+ x = jnp.reshape(jnp.swapaxes(att @ v, 1, 2), (batch, n, channels))
180
+ x = nn.Dense(self.dim, name="proj")(x)
181
+ x = nn.Dropout(self.proj_drop)(x, deterministic=not training)
182
+ return x
183
+
184
+
185
+ class SwinTransformerBlock(nn.Module):
186
+
187
+ dim: int
188
+ input_resolution: tuple[int]
189
+ num_heads: int
190
+ window_size: int = 7
191
+ shift_size: int = 0
192
+ mlp_ratio: float = 4.
193
+ qkv_bias: bool = True
194
+ qk_scale: Optional[float] = None
195
+ drop: float = 0.
196
+ attn_drop: float = 0.
197
+ drop_path: float = 0.
198
+ act_layer: Callable = nn.activation.gelu
199
+ norm_layer: Callable = LayerNorm
200
+
201
+ @staticmethod
202
+ def make_att_mask(shift_size, window_size, height, width):
203
+ if shift_size > 0:
204
+ mask = jnp.zeros([1, height, width, 1])
205
+ h_slices = (
206
+ slice(0, -window_size),
207
+ slice(-window_size, -shift_size),
208
+ slice(-shift_size, None),
209
+ )
210
+ w_slices = (
211
+ slice(0, -window_size),
212
+ slice(-window_size, -shift_size),
213
+ slice(-shift_size, None),
214
+ )
215
+
216
+ count = 0
217
+ for h in h_slices:
218
+ for w in w_slices:
219
+ mask = mask.at[:, h, w, :].set(count)
220
+ count += 1
221
+
222
+ mask_windows = window_partition(mask, window_size)
223
+ mask_windows = jnp.reshape(mask_windows, (-1, window_size * window_size))
224
+ att_mask = jnp.expand_dims(mask_windows, 1) - jnp.expand_dims(mask_windows, 2)
225
+ att_mask = jnp.where(att_mask != 0.0, float(-100.0), att_mask)
226
+ att_mask = jnp.where(att_mask == 0.0, float(0.0), att_mask)
227
+ else:
228
+ att_mask = None
229
+
230
+ return att_mask
231
+
232
+ @nn.compact
233
+ def __call__(self, x, x_size, training):
234
+ H, W = x_size
235
+ B, L, C = x.shape
236
+
237
+ if min(self.input_resolution) <= self.window_size:
238
+ # if window size is larger than input resolution, we don't partition windows
239
+ self.shift_size = 0
240
+ self.window_size = min(self.input_resolution)
241
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
242
+
243
+ shortcut = x
244
+ x = self.norm_layer()(x)
245
+ x = x.reshape((B, H, W, C))
246
+
247
+ # cyclic shift
248
+ if self.shift_size > 0:
249
+ shifted_x = jnp.roll(x, (-self.shift_size, -self.shift_size), axis=(1, 2))
250
+ else:
251
+ shifted_x = x
252
+
253
+ # partition windows
254
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
255
+ x_windows = x_windows.reshape((-1, self.window_size * self.window_size, C)) # nW*B, window_size*window_size, C
256
+
257
+ #attn_mask = self.variable(
258
+ # "variables",
259
+ # "attn_mask",
260
+ # self.get_att_mask,
261
+ # self.shift_size,
262
+ # self.window_size,
263
+ # self.input_resolution[0],
264
+ # self.input_resolution[1]
265
+ #)
266
+
267
+ attn_mask = self.make_att_mask(self.shift_size, self.window_size, *self.input_resolution)
268
+
269
+ attn = WindowAttention(self.dim, (self.window_size, self.window_size), self.num_heads,
270
+ self.qkv_bias, self.qk_scale, self.attn_drop, self.drop)
271
+ if self.input_resolution == x_size:
272
+ attn_windows = attn(x_windows, attn_mask, training) # nW*B, window_size*window_size, C
273
+ else:
274
+ # test time
275
+ assert not training
276
+ test_mask = self.make_att_mask(self.shift_size, self.window_size, *x_size)
277
+ attn_windows = attn(x_windows, test_mask, training=False)
278
+
279
+ # merge windows
280
+ attn_windows = attn_windows.reshape((-1, self.window_size, self.window_size, C))
281
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
282
+
283
+ # reverse cyclic shift
284
+ if self.shift_size > 0:
285
+ x = jnp.roll(shifted_x, (self.shift_size, self.shift_size), axis=(1, 2))
286
+ else:
287
+ x = shifted_x
288
+
289
+ x = x.reshape((B, H * W, C))
290
+
291
+ # FFN
292
+ x = shortcut + DropPath(self.drop_path)(x, training)
293
+
294
+ norm = self.norm_layer()(x)
295
+ mlp = Mlp(in_features=self.dim, hidden_features=int(self.dim * self.mlp_ratio),
296
+ act_layer=self.act_layer, drop=self.drop)(norm, training)
297
+ x = x + DropPath(self.drop_path)(mlp, training)
298
+
299
+ return x
300
+
301
+
302
+ class PatchMerging(nn.Module):
303
+ inp_res: Iterable[int]
304
+ dim: int
305
+ norm_layer: Callable = LayerNorm
306
+
307
+ @nn.compact
308
+ def __call__(self, inputs):
309
+ batch, n, channels = inputs.shape
310
+ height, width = self.inp_res[0], self.inp_res[1]
311
+ x = jnp.reshape(inputs, (batch, height, width, channels))
312
+
313
+ x0 = x[:, 0::2, 0::2, :]
314
+ x1 = x[:, 1::2, 0::2, :]
315
+ x2 = x[:, 0::2, 1::2, :]
316
+ x3 = x[:, 1::2, 1::2, :]
317
+
318
+ x = jnp.concatenate([x0, x1, x2, x3], axis=-1)
319
+ x = jnp.reshape(x, (batch, -1, 4 * channels))
320
+ x = self.norm_layer()(x)
321
+ x = nn.Dense(2 * self.dim, use_bias=False)(x)
322
+ return x
323
+
324
+
325
+ class BasicLayer(nn.Module):
326
+
327
+ dim: int
328
+ input_resolution: int
329
+ depth: int
330
+ num_heads: int
331
+ window_size: int
332
+ mlp_ratio: float = 4.
333
+ qkv_bias: bool = True
334
+ qk_scale: Optional[float] = None
335
+ drop: float = 0.
336
+ attn_drop: float = 0.
337
+ drop_path: float = 0.
338
+ norm_layer: Callable = LayerNorm
339
+ downsample: Optional[Callable] = None
340
+
341
+ @nn.compact
342
+ def __call__(self, x, x_size, training):
343
+ for i in range(self.depth):
344
+ x = SwinTransformerBlock(
345
+ self.dim,
346
+ self.input_resolution,
347
+ self.num_heads,
348
+ self.window_size,
349
+ 0 if (i % 2 == 0) else self.window_size // 2,
350
+ self.mlp_ratio,
351
+ self.qkv_bias,
352
+ self.qk_scale,
353
+ self.drop,
354
+ self.attn_drop,
355
+ self.drop_path[i] if isinstance(self.drop_path, (list, tuple)) else self.drop_path,
356
+ norm_layer=self.norm_layer
357
+ )(x, x_size, training)
358
+
359
+ if self.downsample is not None:
360
+ x = self.downsample(self.input_resolution, dim=self.dim, norm_layer=self.norm_layer)(x)
361
+
362
+ return x
363
+
364
+
365
+ class RSTB(nn.Module):
366
+
367
+ dim: int
368
+ input_resolution: int
369
+ depth: int
370
+ num_heads: int
371
+ window_size: int
372
+ mlp_ratio: float = 4.
373
+ qkv_bias: bool = True
374
+ qk_scale: Optional[float] = None
375
+ drop: float = 0.
376
+ attn_drop: float = 0.
377
+ drop_path: float = 0.
378
+ norm_layer: Callable = LayerNorm
379
+ downsample: Optional[Callable] = None
380
+ img_size: int = 224,
381
+ patch_size: int = 4,
382
+ resi_connection: str = '1conv'
383
+
384
+ @nn.compact
385
+ def __call__(self, x, x_size, training):
386
+ res = x
387
+ x = BasicLayer(dim=self.dim,
388
+ input_resolution=self.input_resolution,
389
+ depth=self.depth,
390
+ num_heads=self.num_heads,
391
+ window_size=self.window_size,
392
+ mlp_ratio=self.mlp_ratio,
393
+ qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
394
+ drop=self.drop, attn_drop=self.attn_drop,
395
+ drop_path=self.drop_path,
396
+ norm_layer=self.norm_layer,
397
+ downsample=self.downsample)(x, x_size, training)
398
+
399
+ x = PatchUnEmbed(embed_dim=self.dim)(x, x_size)
400
+
401
+ # resi_connection == '1conv':
402
+ x = nn.Conv(self.dim, (3, 3))(x)
403
+
404
+ x = PatchEmbed()(x)
405
+
406
+ return x + res
407
+
408
+
409
+ class PatchEmbed(nn.Module):
410
+ norm_layer: Optional[Callable] = None
411
+
412
+ @nn.compact
413
+ def __call__(self, x):
414
+ x = x.reshape((x.shape[0], -1, x.shape[-1])) # B Ph Pw C -> B Ph*Pw C
415
+ if self.norm_layer is not None:
416
+ x = self.norm_layer()(x)
417
+ return x
418
+
419
+
420
+ class PatchUnEmbed(nn.Module):
421
+ embed_dim: int = 96
422
+
423
+ @nn.compact
424
+ def __call__(self, x, x_size):
425
+ B, HW, C = x.shape
426
+ x = x.reshape((B, x_size[0], x_size[1], self.embed_dim))
427
+ return x
428
+
429
+
430
+ class SwinIR(nn.Module):
431
+ r""" SwinIR JAX implementation
432
+ Args:
433
+ img_size (int | tuple(int)): Input image size. Default 64
434
+ patch_size (int | tuple(int)): Patch size. Default: 1
435
+ in_chans (int): Number of input image channels. Default: 3
436
+ embed_dim (int): Patch embedding dimension. Default: 96
437
+ depths (tuple(int)): Depth of each Swin Transformer layer.
438
+ num_heads (tuple(int)): Number of attention heads in different layers.
439
+ window_size (int): Window size. Default: 7
440
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
441
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
442
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
443
+ drop_rate (float): Dropout rate. Default: 0
444
+ attn_drop_rate (float): Attention dropout rate. Default: 0
445
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
446
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
447
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
448
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
449
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
450
+ img_range: Image range. 1. or 255.
451
+ """
452
+
453
+ img_size: int = 48
454
+ patch_size: int = 1
455
+ in_chans: int = 3
456
+ embed_dim: int = 180
457
+ depths: tuple = (6, 6, 6, 6, 6, 6)
458
+ num_heads: tuple = (6, 6, 6, 6, 6, 6)
459
+ window_size: int = 8
460
+ mlp_ratio: float = 2.
461
+ qkv_bias: bool = True
462
+ qk_scale: Optional[float] = None
463
+ drop_rate: float = 0.
464
+ attn_drop_rate: float = 0.
465
+ drop_path_rate: float = 0.1
466
+ norm_layer: Callable = LayerNorm
467
+ ape: bool = False
468
+ patch_norm: bool = True
469
+ upscale: int = 2
470
+ img_range: float = 1.
471
+ num_feat: int = 64
472
+
473
+ def pad(self, x):
474
+ _, h, w, _ = x.shape
475
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
476
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
477
+ x = jnp.pad(x, ((0, 0), (0, mod_pad_h), (0, mod_pad_w), (0, 0)), 'reflect')
478
+ return x
479
+
480
+ @nn.compact
481
+ def __call__(self, x, training):
482
+ _, h_before, w_before, _ = x.shape
483
+ x = self.pad(x)
484
+ _, h, w, _ = x.shape
485
+ patches_resolution = [self.img_size // self.patch_size] * 2
486
+ num_patches = patches_resolution[0] * patches_resolution[1]
487
+
488
+ # conv_first
489
+ x = nn.Conv(self.embed_dim, (3, 3))(x)
490
+ res = x
491
+
492
+ # feature extraction
493
+ x_size = (h, w)
494
+ x = PatchEmbed(self.norm_layer if self.patch_norm else None)(x)
495
+
496
+ if self.ape:
497
+ absolute_pos_embed = \
498
+ self.param('ape', trunc_normal(std=.02), (1, num_patches, self.embed_dim))
499
+ x = x + absolute_pos_embed
500
+
501
+ x = nn.Dropout(self.drop_rate, deterministic=not training)(x)
502
+
503
+ dpr = [x.item() for x in np.linspace(0, self.drop_path_rate, sum(self.depths))]
504
+ for i_layer in range(len(self.depths)):
505
+ x = RSTB(
506
+ dim=self.embed_dim,
507
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
508
+ depth=self.depths[i_layer],
509
+ num_heads=self.num_heads[i_layer],
510
+ window_size=self.window_size,
511
+ mlp_ratio=self.mlp_ratio,
512
+ qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
513
+ drop=self.drop_rate, attn_drop=self.attn_drop_rate,
514
+ drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
515
+ norm_layer=self.norm_layer,
516
+ downsample=None,
517
+ img_size=self.img_size,
518
+ patch_size=self.patch_size)(x, x_size, training)
519
+
520
+ x = self.norm_layer()(x) # B L C
521
+ x = PatchUnEmbed(self.embed_dim)(x, x_size)
522
+
523
+ # conv_after_body
524
+ x = nn.Conv(self.embed_dim, (3, 3))(x)
525
+ x = x + res
526
+
527
+ # conv_before_upsample
528
+ x = nn.activation.leaky_relu(nn.Conv(self.num_feat, (3, 3))(x))
529
+
530
+ # revert padding
531
+ x = x[:, :-(h - h_before) or None, :-(w - w_before) or None]
532
+ return x
model/tail.py ADDED
@@ -0,0 +1,18 @@
1
+ import flax.linen as nn
2
+
3
+ from .convnext import ConvNeXt
4
+ from .swin_ir import SwinIR
5
+
6
+
7
+ def build_tail(size: str):
8
+ """ Convenience function to build the three tails described in the paper. """
9
+ if size == 'air':
10
+ return lambda x, _: x
11
+ elif size == 'plus':
12
+ blocks = [(64, 3, True)] * 6 + [(96, 3, True)] * 7 + [(128, 3, True)] * 3
13
+ return ConvNeXt(blocks)
14
+ elif size == 'pro':
15
+ return SwinIR(depths=[7, 6], num_heads=[6, 6])
16
+ else:
17
+ raise NotImplementedError('size: ' + size)
18
+
model/thera.py ADDED
@@ -0,0 +1,175 @@
1
+ import math
2
+
3
+ import jax
4
+ from flax.core import unfreeze, freeze
5
+ import jax.numpy as jnp
6
+ import flax.linen as nn
7
+ from jaxtyping import Array, ArrayLike, PyTree
8
+
9
+ from .edsr import EDSR
10
+ from .rdn import RDN
11
+ from .hyper import Hypernetwork
12
+ from .tail import build_tail
13
+ from .init import uniform_between, linear_up
14
+ from utils import make_grid, interpolate_grid, repeat_vmap
15
+
16
+
17
+ class Thermal(nn.Module):
18
+ w0_scale: float = 1.
19
+
20
+ @nn.compact
21
+ def __call__(self, x: ArrayLike, t, norm, k) -> Array:
22
+ phase = self.param('phase', nn.initializers.uniform(.5), x.shape[-1:])
23
+ return jnp.sin(self.w0_scale * x + phase) * jnp.exp(-(self.w0_scale * norm)**2 * k * t)
24
+
25
+
26
+ class TheraField(nn.Module):
27
+ dim_hidden: int
28
+ dim_out: int
29
+ w0: float = 1.
30
+ c: float = 6.
31
+
32
+ @nn.compact
33
+ def __call__(self, x: ArrayLike, t: ArrayLike, k: ArrayLike, components: ArrayLike) -> Array:
34
+ # coordinate projection according to shared components ("first layer")
35
+ x = x @ components
36
+
37
+ # thermal activations
38
+ norm = jnp.linalg.norm(components, axis=-2)
39
+ x = Thermal(self.w0)(x, t, norm, k)
40
+
41
+ # linear projection from hidden to output space ("second layer")
42
+ w_std = math.sqrt(self.c / self.dim_hidden) / self.w0
43
+ dense_init_fn = uniform_between(-w_std, w_std)
44
+ x = nn.Dense(self.dim_out, kernel_init=dense_init_fn, use_bias=False)(x)
45
+
46
+ return x
47
+
48
+
49
+ class Thera:
50
+
51
+ def __init__(
52
+ self,
53
+ hidden_dim: int,
54
+ out_dim: int,
55
+ backbone: nn.Module,
56
+ tail: nn.Module,
57
+ k_init: float = None,
58
+ components_init_scale: float = None
59
+ ):
60
+ self.hidden_dim = hidden_dim
61
+ self.k_init = k_init
62
+ self.components_init_scale = components_init_scale
63
+
64
+ # single TheraField object whose `apply` method is used for all grid cells
65
+ self.field = TheraField(hidden_dim, out_dim)
66
+
67
+ # infer output size of the hypernetwork from a sample pass through the field;
68
+ # key doesn't matter as field params are only used for size inference
69
+ sample_params = self.field.init(jax.random.PRNGKey(0),
70
+ jnp.zeros((2,)), 0., 0., jnp.zeros((2, hidden_dim)))
71
+ sample_params_flat, tree_def = jax.tree_util.tree_flatten(sample_params)
72
+ param_shapes = [p.shape for p in sample_params_flat]
73
+
74
+ self.hypernet = Hypernetwork(backbone, tail, param_shapes, tree_def)
75
+
76
+ def init(self, key, sample_source) -> PyTree:
77
+ keys = jax.random.split(key, 2)
78
+ sample_coords = jnp.zeros(sample_source.shape[:-1] + (2,))
79
+ params = unfreeze(self.hypernet.init(keys[0], sample_source, sample_coords))
80
+
81
+ params['params']['k'] = jnp.array(self.k_init)
82
+ params['params']['components'] = \
83
+ linear_up(self.components_init_scale)(keys[1], (2, self.hidden_dim))
84
+
85
+ return freeze(params)
86
+
87
+ def apply_encoder(self, params: PyTree, source: ArrayLike, **kwargs) -> Array:
88
+ """
89
+ Performs a forward pass through the hypernetwork to obtain an encoding.
90
+ """
91
+ return self.hypernet.apply(
92
+ params, source, method=self.hypernet.get_encoding, **kwargs)
93
+
94
+ def apply_decoder(
95
+ self,
96
+ params: PyTree,
97
+ encoding: ArrayLike,
98
+ coords: ArrayLike,
99
+ t: ArrayLike,
100
+ return_jac: bool = False
101
+ ) -> Array | tuple[Array, Array]:
102
+ """
103
+ Performs a forward prediction through a grid of HxW Thera fields,
104
+ informed by `encoding`, at spatial and temporal coordinates
105
+ `coords` and `t`, respectively.
106
+ args:
107
+ params: Field parameters, shape (B, H, W, N)
108
+ encoding: Encoding tensor, shape (B, H, W, C)
109
+ coords: Spatial coordinates in [-0.5, 0.5], shape (B, H, W, 2)
110
+ t: Temporal coordinates, shape (B, 1)
111
+ """
112
+ phi_params: PyTree = self.hypernet.apply(
113
+ params, encoding, coords, method=self.hypernet.get_params_at_coords)
114
+
115
+ # create local coordinate systems
116
+ source_grid = jnp.asarray(make_grid(encoding.shape[-3:-1]))
117
+ source_coords = jnp.tile(source_grid, (encoding.shape[0], 1, 1, 1))
118
+ interp_coords = interpolate_grid(coords, source_coords)
119
+ rel_coords = (coords - interp_coords)
120
+ rel_coords = rel_coords.at[..., 0].set(rel_coords[..., 0] * encoding.shape[-3])
121
+ rel_coords = rel_coords.at[..., 1].set(rel_coords[..., 1] * encoding.shape[-2])
122
+
123
+ # three maps over params, coords; one over t; don't map k and components
124
+ in_axes = [(0, 0, None, None, None), (0, 0, None, None, None), (0, 0, 0, None, None)]
125
+ apply_field = repeat_vmap(self.field.apply, in_axes)
126
+ out = apply_field(phi_params, rel_coords, t, params['params']['k'],
127
+ params['params']['components'])
128
+
129
+ if return_jac:
130
+ apply_jac = repeat_vmap(jax.jacrev(self.field.apply, argnums=1), in_axes)
131
+ jac = apply_jac(phi_params, rel_coords, jnp.zeros_like(t), params['params']['k'],
132
+ params['params']['components'])
133
+ return out, jac
134
+
135
+ return out
136
+
137
+ def apply(
138
+ self,
139
+ params: ArrayLike,
140
+ source: ArrayLike,
141
+ coords: ArrayLike,
142
+ t: ArrayLike,
143
+ return_jac: bool = False,
144
+ **kwargs
145
+ ) -> Array:
146
+ """
147
+ Performs a forward pass through the Thera model.
148
+ """
149
+ encoding = self.apply_encoder(params, source, **kwargs)
150
+ out = self.apply_decoder(params, encoding, coords, t, return_jac=return_jac)
151
+ return out
152
+
153
+
154
+ def build_thera(
155
+ out_dim: int,
156
+ backbone: str,
157
+ size: str,
158
+ k_init: float = None,
159
+ components_init_scale: float = None
160
+ ):
161
+ """
162
+ Convenience function for building the three Thera sizes described in the paper.
163
+ """
164
+ hidden_dim = 32 if size == 'air' else 512
165
+
166
+ if backbone == 'edsr-baseline':
167
+ backbone_module = EDSR(None, num_blocks=16, num_feats=64)
168
+ elif backbone == 'rdn':
169
+ backbone_module = RDN()
170
+ else:
171
+ raise NotImplementedError(backbone)
172
+
173
+ tail_module = build_tail(size)
174
+
175
+ return Thera(hidden_dim, out_dim, backbone_module, tail_module, k_init, components_init_scale)
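A sketch of how the pieces compose end to end; shapes and the k_init / components_init_scale values are illustrative, and super_resolve.py shows the full patched, geo-ensembled version:

import jax
import jax.numpy as jnp
from model import build_thera
from utils import make_grid

model = build_thera(out_dim=3, backbone='edsr-baseline', size='air',
                    k_init=1.0, components_init_scale=1.0)   # illustrative values
source = jnp.zeros((1, 48, 48, 3))                   # (B, H', W', C) low-res input
params = model.init(jax.random.PRNGKey(0), source)   # hypernetwork params + k + components
coords = jnp.asarray(make_grid((96, 96)))[None]      # (B, H, W, 2) target coords in [-0.5, 0.5]
t = jnp.array([(96 / 48) ** -2], dtype=jnp.float32)  # thermal time ~ scale**-2, as in super_resolve.py
out = model.apply(params, source, coords, t)         # (1, 96, 96, 3) residual prediction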
requirements.txt CHANGED
@@ -1,6 +1,37 @@
1
- gradio
2
- torch
3
- peft
4
- transformers
 
5
  diffusers
6
- torchvision
1
+ -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
2
+
3
+ ConfigArgParse==1.7
4
+ Pillow==10.0.0
5
+ chex==0.1.7
6
  diffusers
7
+ einops==0.6.1
8
+ flax==0.6.10
9
+ flaxmodels==0.1.3
10
+ jax==0.4.11
11
+ jaxlib==0.4.11+cuda11.cudnn86
12
+ jaxtyping==0.2.20
13
+ ml-dtypes==0.1.0
14
+ numpy==1.24.1
15
+ nvidia-cublas-cu11==11.11.3.6
16
+ nvidia-cuda-cupti-cu11==11.8.87
17
+ nvidia-cuda-nvcc-cu11==11.8.89
18
+ nvidia-cuda-runtime-cu11==11.8.89
19
+ nvidia-cudnn-cu11==8.9.2.26
20
+ nvidia-cufft-cu11==10.9.0.58
21
+ nvidia-cusolver-cu11==11.4.1.48
22
+ nvidia-cusparse-cu11==11.7.5.86
23
+ opt-einsum==3.3.0
24
+ optax==0.2.0
25
+ orbax-checkpoint==0.2.4
26
+ peft
27
+ scipy==1.10.1
28
+ timm==0.9.6
29
+ torch
30
+ torchvision
31
+ tqdm==4.65.0
32
+ transformers==4.46.3
33
+ wandb
34
+
35
+ gradio==4.44.1
36
+ gradio_imageslider==0.0.20
37
+ spaces
super_resolve.py ADDED
@@ -0,0 +1,99 @@
1
+ #!/usr/bin/env python
2
+
3
+ from argparse import ArgumentParser, Namespace
4
+ import pickle
5
+
6
+ import jax
7
+ from jax import jit
8
+ import jax.numpy as jnp
9
+ import numpy as np
10
+ from PIL import Image
11
+
12
+ from model import build_thera
13
+ from utils import make_grid, interpolate_grid
14
+
15
+ MEAN = jnp.array([.4488, .4371, .4040])
16
+ VAR = jnp.array([.25, .25, .25])
17
+ PATCH_SIZE = 256
18
+
19
+
20
+ def process_single(source, apply_encoder, apply_decoder, params, target_shape):
21
+ t = jnp.float32((target_shape[0] / source.shape[1])**-2)[None]
22
+ coords_nearest = jnp.asarray(make_grid(target_shape)[None])
23
+ source_up = interpolate_grid(coords_nearest, source[None])
24
+ source = jax.nn.standardize(source, mean=MEAN, variance=VAR)[None]
25
+
26
+ encoding = apply_encoder(params, source)
27
+ coords = jnp.asarray(make_grid(source_up.shape[1:3])[None]) # global sampling coords
28
+ out = jnp.full_like(source_up, jnp.nan, dtype=jnp.float32)
29
+
30
+ for h_min in range(0, coords.shape[1], PATCH_SIZE):
31
+ h_max = min(h_min + PATCH_SIZE, coords.shape[1])
32
+ for w_min in range(0, coords.shape[2], PATCH_SIZE):
33
+ # apply decoder with one patch of coordinates
34
+ w_max = min(w_min + PATCH_SIZE, coords.shape[2])
35
+ coords_patch = coords[:, h_min:h_max, w_min:w_max]
36
+ out_patch = apply_decoder(params, encoding, coords_patch, t)
37
+ out = out.at[:, h_min:h_max, w_min:w_max].set(out_patch)
38
+
39
+ out = out * jnp.sqrt(VAR)[None, None, None] + MEAN[None, None, None]
40
+ out += source_up
41
+ return out
42
+
43
+
44
+ def process(source, model, params, target_shape, do_ensemble=True):
45
+ apply_encoder = jit(model.apply_encoder)
46
+ apply_decoder = jit(model.apply_decoder)
47
+
48
+ outs = []
49
+ for i_rot in range(4 if do_ensemble else 1):
50
+ source_ = jnp.rot90(source, k=i_rot, axes=(-3, -2))
51
+ target_shape_ = tuple(reversed(target_shape)) if i_rot % 2 else target_shape
52
+ out = process_single(source_, apply_encoder, apply_decoder, params, target_shape_)
53
+ outs.append(jnp.rot90(out, k=i_rot, axes=(-2, -3)))
54
+
55
+ out = jnp.stack(outs).mean(0).clip(0., 1.)
56
+ return jnp.rint(out[0] * 255).astype(jnp.uint8)
57
+
58
+
59
+ def main(args: Namespace):
60
+ source = np.asarray(Image.open(args.in_file)) / 255.
61
+
62
+ if args.scale is not None:
63
+ if args.size is not None:
64
+ raise ValueError('Cannot specify both size and scale')
65
+ target_shape = (
66
+ round(source.shape[0] * args.scale),
67
+ round(source.shape[1] * args.scale),
68
+ )
69
+ elif args.size is not None:
70
+ target_shape = args.size
71
+ else:
72
+ raise ValueError('Must specify either size or scale')
73
+
74
+ with open(args.checkpoint, 'rb') as fh:
75
+ check = pickle.load(fh)
76
+ params, backbone, size = check['model'], check['backbone'], check['size']
77
+
78
+ model = build_thera(3, backbone, size)
79
+
80
+ out = process(source, model, params, target_shape, not args.no_ensemble)
81
+
82
+ Image.fromarray(np.asarray(out)).save(args.out_file)
83
+
84
+
85
+ def parse_args() -> Namespace:
86
+ parser = ArgumentParser()
87
+ parser.add_argument('in_file')
88
+ parser.add_argument('out_file')
89
+ parser.add_argument('--scale', type=float, help='Scale factor for super-resolution')
90
+ parser.add_argument('--size', type=int, nargs=2,
91
+ help='Target size (h, w), mutually exclusive with --scale')
92
+ parser.add_argument('--checkpoint', help='Path to checkpoint file')
93
+ parser.add_argument('--no-ensemble', action='store_true', help='Disable geo-ensemble')
94
+ return parser.parse_args()
95
+
96
+
97
+ if __name__ == '__main__':
98
+ args = parse_args()
99
+ main(args)
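Typical invocation, with a hypothetical checkpoint filename:

python super_resolve.py input.png output_x4.png --scale 4 --checkpoint thera_edsr.pkl

Use --size H W instead of --scale for an exact output resolution, and --no-ensemble to skip the 4-rotation geo-ensemble.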
utils.py ADDED
@@ -0,0 +1,36 @@
1
+ from functools import partial
2
+
3
+ import jax
4
+ import numpy as np
5
+
6
+
7
+ def repeat_vmap(fun, in_axes=[0]):
8
+ for axes in in_axes:
9
+ fun = jax.vmap(fun, in_axes=axes)
10
+ return fun
11
+
12
+
13
+ def make_grid(patch_size: int | tuple[int, int]):
14
+ if isinstance(patch_size, int):
15
+ patch_size = (patch_size, patch_size)
16
+ offset_h, offset_w = 1 / (2 * np.array(patch_size))
17
+ space_h = np.linspace(-0.5 + offset_h, 0.5 - offset_h, patch_size[0])
18
+ space_w = np.linspace(-0.5 + offset_w, 0.5 - offset_w, patch_size[1])
19
+ return np.stack(np.meshgrid(space_h, space_w, indexing='ij'), axis=-1) # [h, w]
20
+
21
+
22
+ def interpolate_grid(coords, grid, order=0):
23
+ """
24
+ args:
25
+ coords: Tensor of shape (B, H, W, 2) with coordinates in [-0.5, 0.5]
26
+ grid: Tensor of shape (B, H', W', C)
27
+ returns:
28
+ Tensor of shape (B, H, W, C) with interpolated values
29
+ """
30
+ # convert [-0.5, 0.5] -> [0, size], where pixel centers are expected at
31
+ # [-0.5 + 1 / (2*size), ..., 0.5 - 1 / (2*size)]
32
+ coords = coords.transpose((0, 3, 1, 2))
33
+ coords = coords.at[:, 0].set(coords[:, 0] * grid.shape[-3] + (grid.shape[-3] - 1) / 2)
34
+ coords = coords.at[:, 1].set(coords[:, 1] * grid.shape[-2] + (grid.shape[-2] - 1) / 2)
35
+ map_coordinates = partial(jax.scipy.ndimage.map_coordinates, order=order, mode='nearest')
36
+ return jax.vmap(jax.vmap(map_coordinates, in_axes=(2, None), out_axes=2))(grid, coords)
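For reference, make_grid plus interpolate_grid (order=0) is the nearest-neighbour upsampling used for source_up in super_resolve.process_single; a minimal sketch:

import jax.numpy as jnp
from utils import make_grid, interpolate_grid

source = jnp.zeros((1, 32, 48, 3))               # (B, H', W', C) low-res grid
coords = jnp.asarray(make_grid((64, 96)))[None]  # (B, H, W, 2) target coords in [-0.5, 0.5]
upsampled = interpolate_grid(coords, source)     # (1, 64, 96, 3), nearest-neighbour sampling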