Spaces:
Runtime error
Runtime error
innat
commited on
Commit
•
0f09377
1
Parent(s):
f1deb8a
init
Browse files- .gitignore +132 -0
- README.md +9 -5
- app.py +95 -0
- config.py +30 -0
- examples/daisy.jpg +0 -0
- examples/dandelion.jpg +0 -0
- examples/rose.jpg +0 -0
- examples/sunflower.jpg +0 -0
- examples/tulip.jpg +0 -0
- layers/__init__.py +0 -0
- layers/swin_blocks.py +139 -0
- layers/window_attention.py +111 -0
- models/__init__.py +1 -0
- models/hybrid_model.py +170 -0
- requirements.txt +7 -0
- utils/__init__.py +0 -0
- utils/drop_path.py +31 -0
- utils/model_utils.py +46 -0
- utils/patch.py +80 -0
- utils/swin_window.py +25 -0
- utils/viz_utils.py +64 -0
.gitignore
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
130 |
+
|
131 |
+
# Pycharm
|
132 |
+
.idea/
|
README.md
CHANGED
@@ -1,12 +1,16 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: purple
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.0.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: Demo
|
3 |
+
emoji: 🔥
|
4 |
colorFrom: purple
|
5 |
+
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.0.15
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
+
## Visual Interpretation of a Hybrid Model
|
13 |
+
|
14 |
+
Building a hybrid model with *EfficientNet* and *Swin Transformer*, we have tried to inspect the visual interpretations of a CNN and Transformer blocks of a hybrid model (CNN + Swin Transformer) with the GradCAM technique. As a result, it appears that the transformer blocks are capable of globally refining feature activation across the relevant object, as opposed to the CNN, which is more focused on operating locally. However, the approach that will be shown here, is experimental. The workflow probably can generate a more meaningful modeling approach. The model is trained on [tf_flowers](https://www.tensorflow.org/datasets/catalog/tf_flowers) dataset, a multi-class classification problem.
|
15 |
+
|
16 |
+
![]('./Presentation2.png')
|
app.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import gdown
|
4 |
+
import gradio as gr
|
5 |
+
import tensorflow as tf
|
6 |
+
|
7 |
+
from config import Parameters
|
8 |
+
from models.hybrid_model import GradientAccumulation
|
9 |
+
from utils.model_utils import *
|
10 |
+
from utils.viz_utils import make_gradcam_heatmap
|
11 |
+
from utils.viz_utils import save_and_display_gradcam
|
12 |
+
|
13 |
+
image_size = Parameters().image_size
|
14 |
+
str_labels = [
|
15 |
+
"daisy",
|
16 |
+
"dandelion",
|
17 |
+
"roses",
|
18 |
+
"sunflowers",
|
19 |
+
"tulips",
|
20 |
+
]
|
21 |
+
|
22 |
+
|
23 |
+
def get_model():
|
24 |
+
"""Get the model."""
|
25 |
+
model = GradientAccumulation(
|
26 |
+
n_gradients=params.num_grad_accumulation, model_name="HybridModel"
|
27 |
+
)
|
28 |
+
_ = model(tf.ones((1, params.image_size, params.image_size, 3)))[0].shape
|
29 |
+
return model
|
30 |
+
|
31 |
+
|
32 |
+
def get_model_weight(model_id):
|
33 |
+
"""Get the trained weights."""
|
34 |
+
if not os.path.exists("model.h5"):
|
35 |
+
model_weight = gdown.download(id=model_id, quiet=False)
|
36 |
+
else:
|
37 |
+
model_weight = "model.h5"
|
38 |
+
return model_weight
|
39 |
+
|
40 |
+
|
41 |
+
def load_model(model_id):
|
42 |
+
"""Load trained model."""
|
43 |
+
weight = get_model_weight(model_id)
|
44 |
+
model = get_model()
|
45 |
+
model.load_weights(weight)
|
46 |
+
return model
|
47 |
+
|
48 |
+
|
49 |
+
def image_process(image):
|
50 |
+
"""Image preprocess for model input."""
|
51 |
+
image = tf.cast(image, dtype=tf.float32)
|
52 |
+
original_shape = image.shape
|
53 |
+
image = tf.image.resize(image, [image_size, image_size])
|
54 |
+
image = image[tf.newaxis, ...]
|
55 |
+
return image, original_shape
|
56 |
+
|
57 |
+
|
58 |
+
def predict_fn(image):
|
59 |
+
"""A predict function that will be invoked by gradio."""
|
60 |
+
loaded_model = load_model(model_id="1y6tseN0194T6d-4iIh5wo7RL9ttQERe0")
|
61 |
+
loaded_image, original_shape = image_process(image)
|
62 |
+
|
63 |
+
heatmap_a, heatmap_b, preds = make_gradcam_heatmap(loaded_image, loaded_model)
|
64 |
+
int_label = tf.argmax(preds, axis=-1).numpy()[0]
|
65 |
+
str_label = str_labels[int_label]
|
66 |
+
|
67 |
+
overaly_a = save_and_display_gradcam(
|
68 |
+
loaded_image[0], heatmap_a, image_shape=original_shape[:2]
|
69 |
+
)
|
70 |
+
overlay_b = save_and_display_gradcam(
|
71 |
+
loaded_image[0], heatmap_b, image_shape=original_shape[:2]
|
72 |
+
)
|
73 |
+
|
74 |
+
return [f"Predicted: {str_label}", overaly_a, overlay_b]
|
75 |
+
|
76 |
+
|
77 |
+
iface = gr.Interface(
|
78 |
+
fn=predict_fn,
|
79 |
+
inputs=gr.inputs.Image(label="Input Image"),
|
80 |
+
outputs=[
|
81 |
+
gr.outputs.Label(label="Prediction"),
|
82 |
+
gr.inputs.Image(label="CNN GradCAM"),
|
83 |
+
gr.inputs.Image(label="Transformer GradCAM"),
|
84 |
+
],
|
85 |
+
title="Hybrid EfficientNet Swin Transformer Demo",
|
86 |
+
description="The model is trained on tf_flowers dataset.",
|
87 |
+
examples=[
|
88 |
+
["examples/dandelion.jpg"],
|
89 |
+
["examples/sunflower.jpg"],
|
90 |
+
["examples/tulip.jpg"],
|
91 |
+
["examples/daisy.jpg"],
|
92 |
+
["examples/rose.jpg"],
|
93 |
+
],
|
94 |
+
)
|
95 |
+
iface.launch()
|
config.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import tensorflow as tf
|
3 |
+
|
4 |
+
|
5 |
+
class Parameters:
|
6 |
+
# data level
|
7 |
+
image_count = 3670
|
8 |
+
image_size = 384
|
9 |
+
batch_size = 12
|
10 |
+
num_grad_accumulation = 8
|
11 |
+
label_smooth = 0.05
|
12 |
+
class_number = 5
|
13 |
+
val_split = 0.2
|
14 |
+
autotune = tf.data.AUTOTUNE
|
15 |
+
|
16 |
+
# hparams
|
17 |
+
epochs = 10
|
18 |
+
lr_sched = "cosine_restart"
|
19 |
+
lr_base = 0.016
|
20 |
+
lr_min = 0
|
21 |
+
lr_decay_epoch = 2.4
|
22 |
+
lr_warmup_epoch = 5
|
23 |
+
lr_decay_factor = 0.97
|
24 |
+
|
25 |
+
scaled_lr = lr_base * (batch_size / 256.0)
|
26 |
+
scaled_lr_min = lr_min * (batch_size / 256.0)
|
27 |
+
num_validation_sample = int(image_count * val_split)
|
28 |
+
num_training_sample = image_count - num_validation_sample
|
29 |
+
train_step = int(np.ceil(num_training_sample / float(batch_size)))
|
30 |
+
total_steps = train_step * epochs
|
examples/daisy.jpg
ADDED
examples/dandelion.jpg
ADDED
examples/rose.jpg
ADDED
examples/sunflower.jpg
ADDED
examples/tulip.jpg
ADDED
layers/__init__.py
ADDED
File without changes
|
layers/swin_blocks.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
try:
|
2 |
+
from jax import numpy as jnp
|
3 |
+
except ModuleNotFoundError:
|
4 |
+
# jax doesn't support windows os yet.
|
5 |
+
import numpy as jnp
|
6 |
+
|
7 |
+
import tensorflow as tf
|
8 |
+
from tensorflow import keras
|
9 |
+
from tensorflow.keras import layers
|
10 |
+
|
11 |
+
from layers.window_attention import WindowAttention
|
12 |
+
from utils.drop_path import DropPath
|
13 |
+
from utils.swin_window import window_partition
|
14 |
+
from utils.swin_window import window_reverse
|
15 |
+
|
16 |
+
|
17 |
+
class SwinTransformer(layers.Layer):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
dim,
|
21 |
+
num_patch,
|
22 |
+
num_heads,
|
23 |
+
window_size=7,
|
24 |
+
shift_size=0,
|
25 |
+
num_mlp=1024,
|
26 |
+
qkv_bias=True,
|
27 |
+
dropout_rate=0.0,
|
28 |
+
**kwargs,
|
29 |
+
):
|
30 |
+
super(SwinTransformer, self).__init__(**kwargs)
|
31 |
+
|
32 |
+
self.dim = dim
|
33 |
+
self.num_patch = num_patch
|
34 |
+
self.num_heads = num_heads
|
35 |
+
self.window_size = window_size
|
36 |
+
self.shift_size = shift_size
|
37 |
+
self.num_mlp = num_mlp
|
38 |
+
|
39 |
+
self.norm1 = layers.LayerNormalization(epsilon=1e-5)
|
40 |
+
self.attn = WindowAttention(
|
41 |
+
dim,
|
42 |
+
window_size=(self.window_size, self.window_size),
|
43 |
+
num_heads=num_heads,
|
44 |
+
qkv_bias=qkv_bias,
|
45 |
+
dropout_rate=dropout_rate,
|
46 |
+
)
|
47 |
+
self.drop_path = DropPath(dropout_rate) if dropout_rate > 0.0 else tf.identity
|
48 |
+
self.norm2 = layers.LayerNormalization(epsilon=1e-5)
|
49 |
+
|
50 |
+
self.mlp = keras.Sequential(
|
51 |
+
[
|
52 |
+
layers.Dense(num_mlp),
|
53 |
+
layers.Activation(keras.activations.gelu),
|
54 |
+
layers.Dropout(dropout_rate),
|
55 |
+
layers.Dense(dim),
|
56 |
+
layers.Dropout(dropout_rate),
|
57 |
+
]
|
58 |
+
)
|
59 |
+
|
60 |
+
if min(self.num_patch) < self.window_size:
|
61 |
+
self.shift_size = 0
|
62 |
+
self.window_size = min(self.num_patch)
|
63 |
+
|
64 |
+
def build(self, input_shape):
|
65 |
+
if self.shift_size == 0:
|
66 |
+
self.attn_mask = None
|
67 |
+
else:
|
68 |
+
height, width = self.num_patch
|
69 |
+
h_slices = (
|
70 |
+
slice(0, -self.window_size),
|
71 |
+
slice(-self.window_size, -self.shift_size),
|
72 |
+
slice(-self.shift_size, None),
|
73 |
+
)
|
74 |
+
w_slices = (
|
75 |
+
slice(0, -self.window_size),
|
76 |
+
slice(-self.window_size, -self.shift_size),
|
77 |
+
slice(-self.shift_size, None),
|
78 |
+
)
|
79 |
+
mask_array = jnp.zeros((1, height, width, 1))
|
80 |
+
count = 0
|
81 |
+
for h in h_slices:
|
82 |
+
for w in w_slices:
|
83 |
+
mask_array[:, h, w, :] = count
|
84 |
+
count += 1
|
85 |
+
mask_array = tf.convert_to_tensor(mask_array)
|
86 |
+
|
87 |
+
# mask array to windows
|
88 |
+
mask_windows = window_partition(mask_array, self.window_size)
|
89 |
+
mask_windows = tf.reshape(
|
90 |
+
mask_windows, shape=[-1, self.window_size * self.window_size]
|
91 |
+
)
|
92 |
+
attn_mask = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(
|
93 |
+
mask_windows, axis=2
|
94 |
+
)
|
95 |
+
attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
|
96 |
+
attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)
|
97 |
+
self.attn_mask = tf.Variable(initial_value=attn_mask, trainable=False)
|
98 |
+
|
99 |
+
def call(self, x):
|
100 |
+
height, width = self.num_patch
|
101 |
+
_, num_patches_before, channels = x.shape
|
102 |
+
x_skip = x
|
103 |
+
x = self.norm1(x)
|
104 |
+
x = tf.reshape(x, shape=(-1, height, width, channels))
|
105 |
+
if self.shift_size > 0:
|
106 |
+
shifted_x = tf.roll(
|
107 |
+
x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
|
108 |
+
)
|
109 |
+
else:
|
110 |
+
shifted_x = x
|
111 |
+
|
112 |
+
x_windows = window_partition(shifted_x, self.window_size)
|
113 |
+
x_windows = tf.reshape(
|
114 |
+
x_windows, shape=(-1, self.window_size * self.window_size, channels)
|
115 |
+
)
|
116 |
+
attn_windows = self.attn(x_windows, mask=self.attn_mask)
|
117 |
+
|
118 |
+
attn_windows = tf.reshape(
|
119 |
+
attn_windows, shape=(-1, self.window_size, self.window_size, channels)
|
120 |
+
)
|
121 |
+
shifted_x = window_reverse(
|
122 |
+
attn_windows, self.window_size, height, width, channels
|
123 |
+
)
|
124 |
+
if self.shift_size > 0:
|
125 |
+
x = tf.roll(
|
126 |
+
shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
|
127 |
+
)
|
128 |
+
else:
|
129 |
+
x = shifted_x
|
130 |
+
|
131 |
+
x = tf.reshape(x, shape=(-1, height * width, channels))
|
132 |
+
x = self.drop_path(x)
|
133 |
+
x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)
|
134 |
+
x_skip = x
|
135 |
+
x = self.norm2(x)
|
136 |
+
x = self.mlp(x)
|
137 |
+
x = self.drop_path(x)
|
138 |
+
x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)
|
139 |
+
return x
|
layers/window_attention.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import layers
|
3 |
+
|
4 |
+
|
5 |
+
class WindowAttention(layers.Layer):
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
dim,
|
9 |
+
window_size,
|
10 |
+
num_heads,
|
11 |
+
qkv_bias=True,
|
12 |
+
dropout_rate=0.0,
|
13 |
+
return_attention_scores=False,
|
14 |
+
**kwargs,
|
15 |
+
):
|
16 |
+
super().__init__(**kwargs)
|
17 |
+
self.dim = dim
|
18 |
+
self.window_size = window_size
|
19 |
+
self.num_heads = num_heads
|
20 |
+
self.scale = (dim // num_heads) ** -0.5
|
21 |
+
self.return_attention_scores = return_attention_scores
|
22 |
+
self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias)
|
23 |
+
self.dropout = layers.Dropout(dropout_rate)
|
24 |
+
self.proj = layers.Dense(dim)
|
25 |
+
|
26 |
+
def build(self, input_shape):
|
27 |
+
self.relative_position_bias_table = self.add_weight(
|
28 |
+
shape=(
|
29 |
+
(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
|
30 |
+
self.num_heads,
|
31 |
+
),
|
32 |
+
initializer="zeros",
|
33 |
+
trainable=True,
|
34 |
+
name="relative_position_bias_table",
|
35 |
+
)
|
36 |
+
|
37 |
+
self.relative_position_index = self.get_relative_position_index(
|
38 |
+
self.window_size[0], self.window_size[1]
|
39 |
+
)
|
40 |
+
super().build(input_shape)
|
41 |
+
|
42 |
+
def get_relative_position_index(self, window_height, window_width):
|
43 |
+
x_x, y_y = tf.meshgrid(range(window_height), range(window_width))
|
44 |
+
coords = tf.stack([y_y, x_x], axis=0)
|
45 |
+
coords_flatten = tf.reshape(coords, [2, -1])
|
46 |
+
|
47 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
48 |
+
relative_coords = tf.transpose(relative_coords, perm=[1, 2, 0])
|
49 |
+
|
50 |
+
x_x = (relative_coords[:, :, 0] + window_height - 1) * (2 * window_width - 1)
|
51 |
+
y_y = relative_coords[:, :, 1] + window_width - 1
|
52 |
+
relative_coords = tf.stack([x_x, y_y], axis=-1)
|
53 |
+
|
54 |
+
return tf.reduce_sum(relative_coords, axis=-1)
|
55 |
+
|
56 |
+
def call(self, x, mask=None):
|
57 |
+
_, size, channels = x.shape
|
58 |
+
head_dim = channels // self.num_heads
|
59 |
+
x_qkv = self.qkv(x)
|
60 |
+
x_qkv = tf.reshape(x_qkv, shape=(-1, size, 3, self.num_heads, head_dim))
|
61 |
+
x_qkv = tf.transpose(x_qkv, perm=(2, 0, 3, 1, 4))
|
62 |
+
q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]
|
63 |
+
q = q * self.scale
|
64 |
+
k = tf.transpose(k, perm=(0, 1, 3, 2))
|
65 |
+
attn = q @ k
|
66 |
+
|
67 |
+
relative_position_bias = tf.gather(
|
68 |
+
self.relative_position_bias_table,
|
69 |
+
self.relative_position_index,
|
70 |
+
axis=0,
|
71 |
+
)
|
72 |
+
relative_position_bias = tf.transpose(relative_position_bias, [2, 0, 1])
|
73 |
+
attn = attn + tf.expand_dims(relative_position_bias, axis=0)
|
74 |
+
|
75 |
+
if mask is not None:
|
76 |
+
nW = mask.get_shape()[0]
|
77 |
+
mask_float = tf.cast(
|
78 |
+
tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), tf.float32
|
79 |
+
)
|
80 |
+
attn = (
|
81 |
+
tf.reshape(attn, shape=(-1, nW, self.num_heads, size, size))
|
82 |
+
+ mask_float
|
83 |
+
)
|
84 |
+
attn = tf.reshape(attn, shape=(-1, self.num_heads, size, size))
|
85 |
+
attn = tf.nn.softmax(attn, axis=-1)
|
86 |
+
else:
|
87 |
+
attn = tf.nn.softmax(attn, axis=-1)
|
88 |
+
attn = self.dropout(attn)
|
89 |
+
|
90 |
+
x_qkv = attn @ v
|
91 |
+
x_qkv = tf.transpose(x_qkv, perm=(0, 2, 1, 3))
|
92 |
+
x_qkv = tf.reshape(x_qkv, shape=(-1, size, channels))
|
93 |
+
x_qkv = self.proj(x_qkv)
|
94 |
+
x_qkv = self.dropout(x_qkv)
|
95 |
+
|
96 |
+
if self.return_attention_scores:
|
97 |
+
return x_qkv, attn
|
98 |
+
else:
|
99 |
+
return x_qkv
|
100 |
+
|
101 |
+
def get_config(self):
|
102 |
+
config = super().get_config()
|
103 |
+
config.update(
|
104 |
+
{
|
105 |
+
"dim": self.dim,
|
106 |
+
"window_size": self.window_size,
|
107 |
+
"num_heads": self.num_heads,
|
108 |
+
"scale": self.scale,
|
109 |
+
}
|
110 |
+
)
|
111 |
+
return config
|
models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
models/hybrid_model.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow import keras
|
3 |
+
from tensorflow.keras import layers
|
4 |
+
|
5 |
+
from layers.swin_blocks import SwinTransformer
|
6 |
+
from utils.model_utils import *
|
7 |
+
from utils.patch import PatchEmbedding
|
8 |
+
from utils.patch import PatchExtract
|
9 |
+
from utils.patch import PatchMerging
|
10 |
+
|
11 |
+
|
12 |
+
class HybridSwinTransformer(keras.Model):
|
13 |
+
def __init__(self, model_name, **kwargs):
|
14 |
+
super().__init__(name=model_name, **kwargs)
|
15 |
+
# base models
|
16 |
+
base = keras.applications.EfficientNetB0(
|
17 |
+
include_top=False,
|
18 |
+
weights=None,
|
19 |
+
input_tensor=keras.Input((params.image_size, params.image_size, 3)),
|
20 |
+
)
|
21 |
+
|
22 |
+
# base model with compatible output which will be an input of transformer model
|
23 |
+
self.new_base = keras.Model(
|
24 |
+
[base.inputs],
|
25 |
+
[base.get_layer("block6a_expand_activation").output, base.output],
|
26 |
+
name="efficientnet",
|
27 |
+
)
|
28 |
+
|
29 |
+
# stuff of swin transformers
|
30 |
+
self.patch_extract = PatchExtract(patch_size)
|
31 |
+
self.patch_embedds = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)
|
32 |
+
self.patch_merging = PatchMerging(
|
33 |
+
(num_patch_x, num_patch_y), embed_dim=embed_dim
|
34 |
+
)
|
35 |
+
|
36 |
+
# swin blocks containers
|
37 |
+
self.swin_sequences = keras.Sequential(name="swin_blocks")
|
38 |
+
for i in range(shift_size):
|
39 |
+
self.swin_sequences.add(
|
40 |
+
SwinTransformer(
|
41 |
+
dim=embed_dim,
|
42 |
+
num_patch=(num_patch_x, num_patch_y),
|
43 |
+
num_heads=num_heads,
|
44 |
+
window_size=window_size,
|
45 |
+
shift_size=i,
|
46 |
+
num_mlp=num_mlp,
|
47 |
+
qkv_bias=qkv_bias,
|
48 |
+
dropout_rate=dropout_rate,
|
49 |
+
)
|
50 |
+
)
|
51 |
+
|
52 |
+
# swin block's head
|
53 |
+
self.swin_head = keras.Sequential(
|
54 |
+
[
|
55 |
+
layers.GlobalAveragePooling1D(),
|
56 |
+
layers.AlphaDropout(0.5),
|
57 |
+
layers.BatchNormalization(),
|
58 |
+
],
|
59 |
+
name="swin_head",
|
60 |
+
)
|
61 |
+
|
62 |
+
# base model's (cnn model) head
|
63 |
+
self.conv_head = keras.Sequential(
|
64 |
+
[
|
65 |
+
layers.GlobalAveragePooling2D(),
|
66 |
+
layers.AlphaDropout(0.5),
|
67 |
+
],
|
68 |
+
name="conv_head",
|
69 |
+
)
|
70 |
+
|
71 |
+
# classifier
|
72 |
+
self.classifier = layers.Dense(
|
73 |
+
params.class_number, activation=None, dtype="float32"
|
74 |
+
)
|
75 |
+
self.build_graph()
|
76 |
+
|
77 |
+
def call(self, inputs, training=None, **kwargs):
|
78 |
+
x, base_gcam_top = self.new_base(inputs)
|
79 |
+
x = self.patch_extract(x)
|
80 |
+
x = self.patch_embedds(x)
|
81 |
+
x = self.swin_sequences(tf.cast(x, dtype=tf.float32))
|
82 |
+
x, swin_gcam_top = self.patch_merging(x)
|
83 |
+
|
84 |
+
swin_top = self.swin_head(x)
|
85 |
+
conv_top = self.conv_head(base_gcam_top)
|
86 |
+
preds = self.classifier(tf.concat([swin_top, conv_top], axis=-1))
|
87 |
+
|
88 |
+
if training: # training phase
|
89 |
+
return preds
|
90 |
+
else: # inference phase
|
91 |
+
return preds, base_gcam_top, swin_gcam_top
|
92 |
+
|
93 |
+
def build_graph(self):
|
94 |
+
x = keras.Input(shape=(params.image_size, params.image_size, 3))
|
95 |
+
return keras.Model(inputs=[x], outputs=self.call(x))
|
96 |
+
|
97 |
+
|
98 |
+
class GradientAccumulation(HybridSwinTransformer):
|
99 |
+
"""ref: https://gist.github.com/innat/ba6740293e7b7b227829790686f2119c"""
|
100 |
+
|
101 |
+
def __init__(self, n_gradients, **kwargs):
|
102 |
+
super().__init__(**kwargs)
|
103 |
+
self.n_gradients = tf.constant(n_gradients, dtype=tf.int32)
|
104 |
+
self.n_acum_step = tf.Variable(0, dtype=tf.int32, trainable=False)
|
105 |
+
self.gradient_accumulation = [
|
106 |
+
tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False)
|
107 |
+
for v in self.trainable_variables
|
108 |
+
]
|
109 |
+
|
110 |
+
def train_step(self, data):
|
111 |
+
# track accumulation step update
|
112 |
+
self.n_acum_step.assign_add(1)
|
113 |
+
|
114 |
+
# Unpack the data. Its structure depends on your model and
|
115 |
+
# on what you pass to `fit()`.
|
116 |
+
x, y = data
|
117 |
+
|
118 |
+
with tf.GradientTape() as tape:
|
119 |
+
y_pred = self(x, training=True) # Forward pass
|
120 |
+
loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
|
121 |
+
|
122 |
+
# Calculate batch gradients
|
123 |
+
gradients = tape.gradient(loss, self.trainable_variables)
|
124 |
+
|
125 |
+
# Accumulate batch gradients
|
126 |
+
for i in range(len(self.gradient_accumulation)):
|
127 |
+
self.gradient_accumulation[i].assign_add(gradients[i])
|
128 |
+
|
129 |
+
# If n_acum_step reach the n_gradients then we apply accumulated gradients to -
|
130 |
+
# update the variables otherwise do nothing
|
131 |
+
tf.cond(
|
132 |
+
tf.equal(self.n_acum_step, self.n_gradients),
|
133 |
+
self.apply_accu_gradients,
|
134 |
+
lambda: None,
|
135 |
+
)
|
136 |
+
|
137 |
+
# Return a dict mapping metric names to current value.
|
138 |
+
# Note that it will include the loss (tracked in self.metrics).
|
139 |
+
self.compiled_metrics.update_state(y, y_pred)
|
140 |
+
return {m.name: m.result() for m in self.metrics}
|
141 |
+
|
142 |
+
def apply_accu_gradients(self):
|
143 |
+
# Update weights
|
144 |
+
self.optimizer.apply_gradients(
|
145 |
+
zip(self.gradient_accumulation, self.trainable_variables)
|
146 |
+
)
|
147 |
+
|
148 |
+
# reset accumulation step
|
149 |
+
self.n_acum_step.assign(0)
|
150 |
+
for i in range(len(self.gradient_accumulation)):
|
151 |
+
self.gradient_accumulation[i].assign(
|
152 |
+
tf.zeros_like(self.trainable_variables[i], dtype=tf.float32)
|
153 |
+
)
|
154 |
+
|
155 |
+
def test_step(self, data):
|
156 |
+
# Unpack the data
|
157 |
+
x, y = data
|
158 |
+
|
159 |
+
# Compute predictions
|
160 |
+
y_pred, base_gcam_top, swin_gcam_top = self(x, training=False)
|
161 |
+
|
162 |
+
# Updates the metrics tracking the loss
|
163 |
+
self.compiled_loss(y, y_pred, regularization_losses=self.losses)
|
164 |
+
|
165 |
+
# Update the metrics.
|
166 |
+
self.compiled_metrics.update_state(y, y_pred)
|
167 |
+
|
168 |
+
# Return a dict mapping metric names to current value.
|
169 |
+
# Note that it will include the loss (tracked in self.metrics).
|
170 |
+
return {m.name: m.result() for m in self.metrics}
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
tensorflow==2.6.4
|
2 |
+
jax==0.3.13
|
3 |
+
jaxlib
|
4 |
+
numpy
|
5 |
+
matplotlib==3.5.2
|
6 |
+
gradio==3.0.15
|
7 |
+
gdown==4.4.0
|
utils/__init__.py
ADDED
File without changes
|
utils/drop_path.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import backend
|
3 |
+
from tensorflow.keras import layers
|
4 |
+
|
5 |
+
|
6 |
+
class DropPath(layers.Layer):
|
7 |
+
def __init__(self, drop_prob=None, **kwargs):
|
8 |
+
super(DropPath, self).__init__(**kwargs)
|
9 |
+
self.drop_prob = drop_prob
|
10 |
+
|
11 |
+
def call(self, inputs, training=None):
|
12 |
+
if self.drop_prob == 0.0 or not training:
|
13 |
+
return inputs
|
14 |
+
else:
|
15 |
+
batch_size = tf.shape(inputs)[0]
|
16 |
+
keep_prob = 1 - self.drop_prob
|
17 |
+
path_mask_shape = (batch_size,) + (1,) * (len(tf.shape(inputs)) - 1)
|
18 |
+
path_mask = tf.floor(backend.random_bernoulli(path_mask_shape, p=keep_prob))
|
19 |
+
outputs = (
|
20 |
+
tf.math.divide(tf.cast(inputs, dtype=tf.float32), keep_prob) * path_mask
|
21 |
+
)
|
22 |
+
return outputs
|
23 |
+
|
24 |
+
def get_config(self):
|
25 |
+
config = super().get_config()
|
26 |
+
config.update(
|
27 |
+
{
|
28 |
+
"drop_prob": self.drop_prob,
|
29 |
+
}
|
30 |
+
)
|
31 |
+
return config
|
utils/model_utils.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import tensorflow as tf
|
3 |
+
|
4 |
+
|
5 |
+
class Parameters:
|
6 |
+
# data level
|
7 |
+
image_count = 3670
|
8 |
+
image_size = 384
|
9 |
+
batch_size = 12
|
10 |
+
num_grad_accumulation = 8
|
11 |
+
class_number = 5
|
12 |
+
val_split = 0.2
|
13 |
+
autotune = tf.data.AUTOTUNE
|
14 |
+
|
15 |
+
# hparams
|
16 |
+
epochs = 10
|
17 |
+
lr_sched = "cosine_restart"
|
18 |
+
lr_base = 0.016
|
19 |
+
lr_min = 0
|
20 |
+
lr_decay_epoch = 2.4
|
21 |
+
lr_warmup_epoch = 5
|
22 |
+
lr_decay_factor = 0.97
|
23 |
+
|
24 |
+
scaled_lr = lr_base * (batch_size / 256.0)
|
25 |
+
scaled_lr_min = lr_min * (batch_size / 256.0)
|
26 |
+
num_validation_sample = int(image_count * val_split)
|
27 |
+
num_training_sample = image_count - num_validation_sample
|
28 |
+
train_step = int(np.ceil(num_training_sample / float(batch_size)))
|
29 |
+
total_steps = train_step * epochs
|
30 |
+
|
31 |
+
|
32 |
+
params = Parameters()
|
33 |
+
|
34 |
+
|
35 |
+
patch_size = (2, 2) # 4-by-4 sized patches
|
36 |
+
dropout_rate = 0.5 # Dropout rate
|
37 |
+
num_heads = 8 # Attention heads
|
38 |
+
embed_dim = 64 # Embedding dimension
|
39 |
+
num_mlp = 128 # MLP layer size
|
40 |
+
qkv_bias = True # Convert embedded patches to query, key, and values with a learnable additive value
|
41 |
+
window_size = 2 # Size of attention window
|
42 |
+
shift_size = 1 # Size of shifting window
|
43 |
+
image_dimension = 24 # Initial image size / Input size of the transformer model
|
44 |
+
|
45 |
+
num_patch_x = image_dimension // patch_size[0]
|
46 |
+
num_patch_y = image_dimension // patch_size[1]
|
utils/patch.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import layers
|
3 |
+
|
4 |
+
|
5 |
+
class PatchExtract(layers.Layer):
|
6 |
+
def __init__(self, patch_size, **kwargs):
|
7 |
+
super().__init__(**kwargs)
|
8 |
+
self.patch_size_x = patch_size[0]
|
9 |
+
self.patch_size_y = patch_size[0]
|
10 |
+
|
11 |
+
def call(self, images):
|
12 |
+
batch_size = tf.shape(images)[0]
|
13 |
+
patches = tf.image.extract_patches(
|
14 |
+
images=images,
|
15 |
+
sizes=(1, self.patch_size_x, self.patch_size_y, 1),
|
16 |
+
strides=(1, self.patch_size_x, self.patch_size_y, 1),
|
17 |
+
rates=(1, 1, 1, 1),
|
18 |
+
padding="VALID",
|
19 |
+
)
|
20 |
+
patch_dim = patches.shape[-1]
|
21 |
+
patch_num = patches.shape[1]
|
22 |
+
return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))
|
23 |
+
|
24 |
+
def get_config(self):
|
25 |
+
config = super().get_config()
|
26 |
+
config.update(
|
27 |
+
{
|
28 |
+
"patch_size_y": self.patch_size_y,
|
29 |
+
"patch_size_x": self.patch_size_x,
|
30 |
+
}
|
31 |
+
)
|
32 |
+
return config
|
33 |
+
|
34 |
+
|
35 |
+
class PatchEmbedding(layers.Layer):
|
36 |
+
def __init__(self, num_patch, embed_dim, **kwargs):
|
37 |
+
super().__init__(**kwargs)
|
38 |
+
self.num_patch = num_patch
|
39 |
+
self.proj = layers.Dense(embed_dim)
|
40 |
+
self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)
|
41 |
+
|
42 |
+
def call(self, patch):
|
43 |
+
pos = tf.range(start=0, limit=self.num_patch, delta=1)
|
44 |
+
return self.proj(patch) + self.pos_embed(pos)
|
45 |
+
|
46 |
+
def get_config(self):
|
47 |
+
config = super().get_config()
|
48 |
+
config.update(
|
49 |
+
{
|
50 |
+
"num_patch": self.num_patch,
|
51 |
+
}
|
52 |
+
)
|
53 |
+
return config
|
54 |
+
|
55 |
+
|
56 |
+
class PatchMerging(layers.Layer):
|
57 |
+
def __init__(self, num_patch, embed_dim):
|
58 |
+
super().__init__()
|
59 |
+
self.num_patch = num_patch
|
60 |
+
self.embed_dim = embed_dim
|
61 |
+
self.linear_trans = layers.Dense(2 * embed_dim, use_bias=False)
|
62 |
+
|
63 |
+
def call(self, x):
|
64 |
+
height, width = self.num_patch
|
65 |
+
_, _, C = x.get_shape().as_list()
|
66 |
+
x = tf.reshape(x, shape=(-1, height, width, C))
|
67 |
+
feat_maps = x
|
68 |
+
|
69 |
+
x0 = x[:, 0::2, 0::2, :]
|
70 |
+
x1 = x[:, 1::2, 0::2, :]
|
71 |
+
x2 = x[:, 0::2, 1::2, :]
|
72 |
+
x3 = x[:, 1::2, 1::2, :]
|
73 |
+
x = tf.concat((x0, x1, x2, x3), axis=-1)
|
74 |
+
x = tf.reshape(x, shape=(-1, (height // 2) * (width // 2), 4 * C))
|
75 |
+
return self.linear_trans(x), feat_maps
|
76 |
+
|
77 |
+
def get_config(self):
|
78 |
+
config = super().get_config()
|
79 |
+
config.update({"num_patch": self.num_patch, "embed_dim": self.embed_dim})
|
80 |
+
return config
|
utils/swin_window.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
|
4 |
+
def window_partition(x, window_size):
|
5 |
+
_, height, width, channels = x.shape
|
6 |
+
patch_num_y = height // window_size
|
7 |
+
patch_num_x = width // window_size
|
8 |
+
x = tf.reshape(
|
9 |
+
x, shape=(-1, patch_num_y, window_size, patch_num_x, window_size, channels)
|
10 |
+
)
|
11 |
+
x = tf.transpose(x, (0, 1, 3, 2, 4, 5))
|
12 |
+
windows = tf.reshape(x, shape=(-1, window_size, window_size, channels))
|
13 |
+
return windows
|
14 |
+
|
15 |
+
|
16 |
+
def window_reverse(windows, window_size, height, width, channels):
|
17 |
+
patch_num_y = height // window_size
|
18 |
+
patch_num_x = width // window_size
|
19 |
+
x = tf.reshape(
|
20 |
+
windows,
|
21 |
+
shape=(-1, patch_num_y, patch_num_x, window_size, window_size, channels),
|
22 |
+
)
|
23 |
+
x = tf.transpose(x, perm=(0, 1, 3, 2, 4, 5))
|
24 |
+
x = tf.reshape(x, shape=(-1, height, width, channels))
|
25 |
+
return x
|
utils/viz_utils.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.cm as cm
|
2 |
+
import numpy as np
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow import keras
|
5 |
+
|
6 |
+
|
7 |
+
def make_gradcam_heatmap(img_array, grad_model, pred_index=None):
|
8 |
+
with tf.GradientTape(persistent=True) as tape:
|
9 |
+
preds, base_top, swin_top = grad_model(img_array)
|
10 |
+
if pred_index is None:
|
11 |
+
pred_index = tf.argmax(preds[0])
|
12 |
+
class_channel = preds[:, pred_index]
|
13 |
+
|
14 |
+
grads = tape.gradient(class_channel, base_top)
|
15 |
+
pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
|
16 |
+
base_top = base_top[0]
|
17 |
+
heatmap_a = base_top @ pooled_grads[..., tf.newaxis]
|
18 |
+
heatmap_a = tf.squeeze(heatmap_a)
|
19 |
+
heatmap_a = tf.maximum(heatmap_a, 0) / tf.math.reduce_max(heatmap_a)
|
20 |
+
heatmap_a = heatmap_a.numpy()
|
21 |
+
|
22 |
+
grads = tape.gradient(class_channel, swin_top)
|
23 |
+
pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
|
24 |
+
swin_top = swin_top[0]
|
25 |
+
heatmap_b = swin_top @ pooled_grads[..., tf.newaxis]
|
26 |
+
heatmap_b = tf.squeeze(heatmap_b)
|
27 |
+
heatmap_b = tf.maximum(heatmap_b, 0) / tf.math.reduce_max(heatmap_b)
|
28 |
+
heatmap_b = heatmap_b.numpy()
|
29 |
+
return heatmap_a, heatmap_b, preds
|
30 |
+
|
31 |
+
|
32 |
+
def save_and_display_gradcam(
|
33 |
+
img,
|
34 |
+
heatmap,
|
35 |
+
target=None,
|
36 |
+
pred=None,
|
37 |
+
cam_path="cam.jpg",
|
38 |
+
cmap="jet", # inferno, viridis
|
39 |
+
alpha=0.6,
|
40 |
+
plot=None,
|
41 |
+
image_shape=None,
|
42 |
+
):
|
43 |
+
# Rescale heatmap to a range 0-255
|
44 |
+
heatmap = np.uint8(255 * heatmap)
|
45 |
+
|
46 |
+
# Use jet colormap to colorize heatmap
|
47 |
+
jet = cm.get_cmap(cmap)
|
48 |
+
|
49 |
+
# Use RGB values of the colormap
|
50 |
+
jet_colors = jet(np.arange(256))[:, :3]
|
51 |
+
jet_heatmap = jet_colors[heatmap]
|
52 |
+
|
53 |
+
# Create an image with RGB colorized heatmap
|
54 |
+
jet_heatmap = keras.utils.array_to_img(jet_heatmap)
|
55 |
+
jet_heatmap = jet_heatmap.resize((img.shape[0], img.shape[1]))
|
56 |
+
jet_heatmap = keras.utils.img_to_array(jet_heatmap)
|
57 |
+
|
58 |
+
# Superimpose the heatmap on original image
|
59 |
+
superimposed_img = img + jet_heatmap * alpha
|
60 |
+
superimposed_img = keras.utils.array_to_img(superimposed_img)
|
61 |
+
|
62 |
+
size_w, size_h = image_shape[:2]
|
63 |
+
superimposed_img = superimposed_img.resize((size_h, size_w))
|
64 |
+
return superimposed_img
|