File size: 7,305 Bytes
4d01101 18903a3 4d01101 18903a3 0a97993 18903a3 4d01101 18903a3 4d01101 18903a3 4d01101 18903a3 4d01101 0a97993 18903a3 0a97993 4d01101 18903a3 0a97993 4d01101 18903a3 0a97993 4d01101 0a97993 18903a3 4d01101 0a97993 4d01101 18903a3 4d01101 18903a3 4d01101 18903a3 4d01101 18903a3 4d01101 18903a3 4d01101 18903a3 4d01101 18903a3 4d01101 18903a3 4d01101 18903a3 4d01101 18903a3 0a97993 4d01101 18903a3 4d01101 18903a3 4d01101 0a97993 4d01101 0a97993 18903a3 0a97993 4d01101 18903a3 4d01101 18903a3 4d01101 18903a3 0a97993 18903a3 4d01101 18903a3 4d01101 18903a3 4d01101 18903a3 4d01101 18903a3 4d01101 18903a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
import streamlit as st
from pathlib import Path
# -----------------------------------------------------------------------------
# main
# -----------------------------------------------------------------------------
def main():
    """Render the SatVision few-shot comparison page.

    The user picks a training-set size from a dropdown. The page then shows
    one row per example: the input image chip, the SatVision-B, UNet and
    UNet-LS predictions, and the MCD12Q1 land-cover target. A short
    description of the approach follows.
    """
    st.title("SatVision Few-Shot Comparison")
    st.write("")
    selected_option = st.selectbox(
        "Number of training samples", [10, 1000, 5000]
    )
    # The control above is a selectbox (dropdown), not a slider.
    st.markdown(
        "Use the dropdown to select how many training "
        "samples the models were trained on"
    )
    images = load_images(selected_option, Path("./images/images"))
    labels = load_labels(selected_option, Path("./images/labels"))
    preds = load_predictions(selected_option, Path("./images/predictions"))

    # Materialize the rows so the grid can be sized to the actual number of
    # examples; a bare zip iterator has no len().
    rows = list(
        zip(images, preds["svb"], preds["unet"], preds["unet-ls"], labels)
    )

    column_titles = [
        "### MOD09GA [3-2-1] Image Chip",
        "### SatVision-B Prediction",
        "### UNet (CNN) Prediction",
        "### UNet (CNN) LS Pretrained Prediction",
        "### MCD12Q1 LandCover Target",
    ]
    st.write("")
    for title_col, title in zip(st.columns(len(column_titles)), column_titles):
        title_col.markdown(title)
    st.write("")

    # make_grid's first argument is the number of grid rows and the second
    # the number of columns per row. Sizing it from the data avoids empty
    # trailing rows (or an IndexError if there are more examples than rows).
    grid = make_grid(len(rows), len(column_titles))
    for grid_row, cells in zip(grid, rows):
        # Each cell is a (path, caption) pair produced by the load_* helpers.
        for grid_cell, (path, caption) in zip(grid_row, cells):
            grid_cell.image(path, caption, use_column_width=True)

    st.markdown("### Few-Shot Learning with SatVision-Base")
    description = (
        "Pre-trained vision transformers (we use SwinV2) offer a "
        "clear advantage when applying a model to a task with very little"
        " labeled training data. We pre-trained SatVision-Base on 26 million"
        " MODIS Surface Reflectance image patches. This allows the"
        " SatVision-Base models to learn relevant features and representations"
        " from a diverse range of scenes. This knowledge can be transferred to a"
        " few-shot learning task, enabling the model to leverage its"
        " understanding of spatial patterns, textures, and contextual information."
    )
    st.markdown(description)
# -----------------------------------------------------------------------------
# load_images
# -----------------------------------------------------------------------------
def load_images(selected_option: int, image_dir: Path):
    """
    Return captioned input-image paths for the selected training-set size.

    Args:
        selected_option: Training-set size chosen in the UI (an int from the
            selectbox); forwarded to ``find_images`` to pick the files.
        image_dir: Directory holding the input image chips.

    Returns:
        A list of ``(path_string, caption)`` tuples, in the order produced by
        ``find_images``, with example numbers starting at 1.
    """
    # NOTE(review): the caption hard-codes tile "H18v04" and year "2019"
    # regardless of which file is matched — confirm this holds for all data.
    return [
        (str(path), f"MOD09GA 3-2-1 H18v04 2019 Example {number}")
        for number, path in enumerate(find_images(selected_option, image_dir), 1)
    ]
# -----------------------------------------------------------------------------
# find_images
# -----------------------------------------------------------------------------
def find_images(selected_option: int, image_dir: Path):
    """
    Return the sorted input-image paths for ``selected_option``.

    Files in ``image_dir`` are matched with the glob pattern
    ``ft_demo_<selected_option>_*_img.png`` and sorted.

    Raises:
        ValueError: If the pattern does not match exactly three files, or if
            the first file in sorted order does not contain "1071". These
            checks are raised explicitly instead of with ``assert`` so they
            still run under ``python -O``.
    """
    pattern = f"ft_demo_{selected_option}_*_img.png"  # glob, not a regex
    matches = sorted(image_dir.glob(pattern))
    if len(matches) != 3:
        raise ValueError(
            f"Should be 3 images matching {pattern!r} in {image_dir}, "
            f"found {len(matches)}"
        )
    # NOTE(review): presumably ensures the examples line up with the labels
    # and predictions, which perform the same check — confirm.
    if "1071" not in str(matches[0]):
        raise ValueError(f"Should be 1071, got {matches[0]}")
    return matches
# -----------------------------------------------------------------------------
# load_labels
# -----------------------------------------------------------------------------
def load_labels(selected_option, label_dir: Path):
    """
    Pair every land-cover label image for ``selected_option`` with a caption.

    Label paths come from ``find_labels`` and are returned in the same order
    as ``(path_string, caption)`` tuples, numbered from 1.
    """
    return [
        (str(label_path), f"MCD12Q1 LandCover Target Example {position}")
        for position, label_path in enumerate(
            find_labels(selected_option, label_dir), start=1
        )
    ]
# -----------------------------------------------------------------------------
# find_labels
# -----------------------------------------------------------------------------
def find_labels(selected_option: int, label_dir: Path):
    """
    Return the sorted label-image paths for ``selected_option``.

    Files in ``label_dir`` are matched with the glob pattern
    ``ft_demo_<selected_option>_*_label.png`` and sorted.

    Raises:
        ValueError: If the pattern does not match exactly three files, or if
            the first file in sorted order does not contain "1071". These
            checks are raised explicitly instead of with ``assert`` so they
            still run under ``python -O``.
    """
    pattern = f"ft_demo_{selected_option}_*_label.png"  # glob, not a regex
    matches = sorted(label_dir.glob(pattern))
    if len(matches) != 3:
        raise ValueError(
            f"Should be 3 label images matching {pattern!r} in {label_dir}, "
            f"found {len(matches)}"
        )
    # NOTE(review): presumably ensures the labels line up with the images
    # and predictions, which perform the same check — confirm.
    if "1071" not in str(matches[0]):
        raise ValueError(f"Should be 1071, got {matches[0]}")
    return matches
# -----------------------------------------------------------------------------
# load_predictions
# -----------------------------------------------------------------------------
def load_predictions(selected_option: int, pred_dir: Path):
    """
    Collect captioned prediction images for each model.

    Args:
        selected_option: Training-set size chosen in the UI (an int from the
            selectbox); forwarded to ``find_preds``.
        pred_dir: Root directory of the prediction images.

    Returns:
        A dict with keys ``"svb"``, ``"unet"`` and ``"unet-ls"``. Each maps to
        a list of ``(path_string, caption)`` tuples numbered from 1.
    """
    # Display key -> (model name passed to find_preds, caption prefix).
    models = {
        "svb": ("svb", "SatVision-B Prediction Example"),
        "unet": ("cnn", "Unet Prediction Example"),
        "unet-ls": ("cnn-ls", "Unet LS Pre-trained Prediction Example"),
    }
    prediction_dict = {}
    for key, (model, caption_prefix) in models.items():
        paths = find_preds(selected_option, pred_dir, model)
        prediction_dict[key] = [
            (str(path), f"{caption_prefix} {i}")
            for i, path in enumerate(paths, 1)
        ]
    return prediction_dict
# -----------------------------------------------------------------------------
# find_preds
# -----------------------------------------------------------------------------
def find_preds(selected_option: int, pred_dir: Path, model: str):
    """
    Return the sorted prediction-image paths for one model.

    Files are searched in ``pred_dir / str(selected_option) / model``.
    ``"cnn"`` and ``"cnn-ls"`` use the ``ft_cnn_demo_...`` naming scheme.
    Any other model name (e.g. ``"svb"``) falls back to
    ``ft_demo_<selected_option>_*_pred.png``.

    Raises:
        FileNotFoundError: If the model-specific directory is missing.
        ValueError: If the pattern does not match exactly three files, or if
            the first file in sorted order does not contain "1071". These
            checks are raised explicitly instead of with ``assert`` so they
            still run under ``python -O``.
    """
    patterns = {
        "cnn": f"ft_cnn_demo_{selected_option}_*cnn-plain_pred.png",
        "cnn-ls": f"ft_cnn_demo_{selected_option}_*cnn-ls_pred.png",
    }
    pattern = patterns.get(model, f"ft_demo_{selected_option}_*_pred.png")
    model_specific_dir = pred_dir / str(selected_option) / model
    if not model_specific_dir.is_dir():
        raise FileNotFoundError(
            f"{model_specific_dir} does not exist or is not a directory"
        )
    matches = sorted(model_specific_dir.glob(pattern))
    if len(matches) != 3:
        raise ValueError(
            f"Should be 3 prediction images matching {pattern!r} in "
            f"{model_specific_dir}, found {len(matches)}"
        )
    # NOTE(review): presumably ensures the predictions line up with the
    # images and labels, which perform the same check — confirm.
    if "1071" not in str(matches[0]):
        raise ValueError(f"Should be 1071, got {matches[0]}")
    return matches
# -----------------------------------------------------------------------------
# make_grid
# -----------------------------------------------------------------------------
def make_grid(cols, rows):
    """
    Build a grid of Streamlit columns, one container per grid row.

    Despite the parameter names, ``cols`` is the number of grid *rows* and
    ``rows`` is the number of columns created in each row. The names are kept
    for caller compatibility.

    Returns:
        A list of length ``cols`` (empty when ``cols <= 0``). Each element is
        the value returned by ``st.columns(rows, gap="large")``, so cells are
        addressed as ``grid[row][column]``.
    """
    grid_rows = []
    for _ in range(cols):
        # Wrap each row in its own container so the rows stack vertically.
        with st.container():
            grid_rows.append(st.columns(rows, gap="large"))
    return grid_rows
# -----------------------------------------------------------------------------
# Main execution
# -----------------------------------------------------------------------------
# Render the page only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|