alexanderkroner committed on
Commit 64f2adb · verified · 1 Parent(s): cfa4919

Update README.md

Files changed (1)
  1. README.md +139 -0
README.md CHANGED
@@ -21,6 +21,145 @@ MSI-Net is a visual saliency model that predicts where humans fixate on natural
 
  <img src="https://github.com/alexanderkroner/saliency/blob/master/figures/architecture.jpg?raw=true" width="700">
 
+ # Example Use
+
+ ### Import the dependencies
+ ```python
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import tensorflow as tf
+ from huggingface_hub import snapshot_download
+ ```
+
+ ### Download the repo files
+ ```python
+ hf_dir = snapshot_download(
+     repo_id="alexanderkroner/MSI-Net", allow_patterns=["*.pb", "*.jpg"]
+ )
+ ```
+
+ ### Load the saliency model
+ ```python
+ model = tf.saved_model.load(hf_dir)
+ model = model.signatures["serving_default"]
+ ```
+
+ ### Load the functions for preprocessing the input and postprocessing the output
+ ```python
+ def get_target_shape(original_shape):
+     original_aspect_ratio = original_shape[0] / original_shape[1]
+
+     square_mode = abs(original_aspect_ratio - 1.0)
+     landscape_mode = abs(original_aspect_ratio - 240 / 320)
+     portrait_mode = abs(original_aspect_ratio - 320 / 240)
+
+     best_mode = min(square_mode, landscape_mode, portrait_mode)
+
+     if best_mode == square_mode:
+         target_shape = (320, 320)
+     elif best_mode == landscape_mode:
+         target_shape = (240, 320)
+     else:
+         target_shape = (320, 240)
+
+     return target_shape
+
+
+ def preprocess_input(input_image, target_shape):
+     input_tensor = tf.expand_dims(input_image, axis=0)
+
+     input_tensor = tf.image.resize(
+         input_tensor, target_shape, preserve_aspect_ratio=True
+     )
+
+     vertical_padding = target_shape[0] - input_tensor.shape[1]
+     horizontal_padding = target_shape[1] - input_tensor.shape[2]
+
+     vertical_padding_1 = vertical_padding // 2
+     vertical_padding_2 = vertical_padding - vertical_padding_1
+
+     horizontal_padding_1 = horizontal_padding // 2
+     horizontal_padding_2 = horizontal_padding - horizontal_padding_1
+
+     input_tensor = tf.pad(
+         input_tensor,
+         [
+             [0, 0],
+             [vertical_padding_1, vertical_padding_2],
+             [horizontal_padding_1, horizontal_padding_2],
+             [0, 0],
+         ],
+     )
+
+     return (
+         input_tensor,
+         [vertical_padding_1, vertical_padding_2],
+         [horizontal_padding_1, horizontal_padding_2],
+     )
+
+
+ def postprocess_output(
+     output_tensor, vertical_padding, horizontal_padding, original_shape
+ ):
+     output_tensor = output_tensor[
+         :,
+         vertical_padding[0] : output_tensor.shape[1] - vertical_padding[1],
+         horizontal_padding[0] : output_tensor.shape[2] - horizontal_padding[1],
+         :,
+     ]
+
+     output_tensor = tf.image.resize(output_tensor, original_shape)
+
+     output_array = output_tensor.numpy().squeeze()
+     output_array = plt.cm.inferno(output_array)[..., :3]
+
+     return output_array
+ ```
+
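
The shape selection and padding arithmetic in the added `get_target_shape` and `preprocess_input` helpers can be traced without TensorFlow. The following is a minimal sketch, not part of the commit: `pick_target_shape` and `split_padding` are hypothetical names introduced here, and the resize step is only approximated under the assumption that an aspect-preserving resize scales both sides by the smaller of the two height/width ratios and rounds to whole pixels.

```python
# Illustrative only: these helpers are NOT part of the committed README.

# (height, width) presets and their aspect ratios, in the same priority
# order as the if/elif chain in get_target_shape.
PRESETS = {
    (320, 320): 1.0,        # square
    (240, 320): 240 / 320,  # landscape
    (320, 240): 320 / 240,  # portrait
}


def pick_target_shape(original_shape):
    """Return the preset whose height/width ratio is closest to the image's."""
    aspect_ratio = original_shape[0] / original_shape[1]
    return min(PRESETS, key=lambda shape: abs(aspect_ratio - PRESETS[shape]))


def split_padding(original_shape, target_shape):
    """Estimate the resized shape and the (before, after) padding per axis.

    Assumption: the resize uses one scale factor for both axes (the smaller
    of the two target/original ratios) and rounds to whole pixels.
    """
    scale = min(
        target_shape[0] / original_shape[0],
        target_shape[1] / original_shape[1],
    )
    resized = (
        round(original_shape[0] * scale),
        round(original_shape[1] * scale),
    )

    paddings = []
    for target, actual in zip(target_shape, resized):
        total = target - actual
        before = total // 2  # same floor split as in preprocess_input
        paddings.append((before, total - before))

    return resized, paddings[0], paddings[1]


original_shape = (500, 640)
target_shape = pick_target_shape(original_shape)
resized, vertical, horizontal = split_padding(original_shape, target_shape)
print(target_shape, resized, vertical, horizontal)
```

Under these assumptions, a 500×640 image maps to the landscape preset (240, 320), is resized to 240×307, and receives 6 and 7 columns of horizontal padding with no vertical padding. `postprocess_output` later crops exactly those margins before resizing the prediction back to the original shape.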
+ ### Load and preprocess an example image
+ ```python
+ input_image = tf.keras.utils.load_img(hf_dir + "/example.jpg")
+ input_image = np.array(input_image, dtype=np.float32)
+
+ original_shape = input_image.shape[:2]
+ target_shape = get_target_shape(original_shape)
+
+ input_tensor, vertical_padding, horizontal_padding = preprocess_input(
+     input_image, target_shape
+ )
+ ```
+
+ ### Feed the input tensor to the model
+ ```python
+ output_tensor = model(input_tensor)["output"]
+ ```
+
+ ### Postprocess and visualize the output
+ ```python
+ saliency_map = postprocess_output(
+     output_tensor, vertical_padding, horizontal_padding, original_shape
+ )
+
+ alpha = 0.65
+
+ blended_image = alpha * saliency_map + (1 - alpha) * input_image / 255
+
+ plt.figure(figsize=(10, 5))
+
+ plt.subplot(1, 2, 1)
+ plt.imshow(input_image / 255)
+ plt.title("Input Image")
+ plt.axis("off")
+
+ plt.subplot(1, 2, 2)
+ plt.imshow(blended_image)
+ plt.title("Saliency Map")
+ plt.axis("off")
+
+ plt.tight_layout()
+ plt.show()
+ ```
+
  # Datasets
 
  Before training the model on fixation data, the encoder weights were initialized from a VGG16 backbone pre-trained on the ImageNet classification task. The model was then trained on the SALICON dataset, which consists of mouse movement recordings as a proxy for gaze measurements. Finally, the weights can be fine-tuned on human eye tracking data. MSI-Net was therefore also trained on one of the following datasets, although here we only provide the SALICON base model: