|
import streamlit as st |
|
from pathlib import Path |
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Render the SatVision few-shot comparison page.

    The user picks how many training samples the models were trained on; the
    page then shows one row per example with five images side by side: the
    MOD09GA input chip, the SatVision-B, UNet and UNet LS-pretrained
    predictions, and the MCD12Q1 land-cover target.
    """
    st.title("SatVision Few-Shot Comparison")
    st.write("")

    selected_option = st.selectbox(
        "Number of training samples", [10, 1000, 5000]
    )
    # The widget is a select box, so the hint must not talk about a slider.
    st.markdown(
        "Select how many training samples the models were trained on"
    )

    images = load_images(selected_option, Path("./images/images"))
    labels = load_labels(selected_option, Path("./images/labels"))
    preds = load_predictions(selected_option, Path("./images/predictions"))

    # Materialize the rows so the grid can be sized to the real number of
    # examples (zip stops at the shortest input).
    rows = list(
        zip(images, preds["svb"], preds["unet"], preds["unet-ls"], labels)
    )

    st.write("")

    column_titles = (
        "### MOD09GA [3-2-1] Image Chip",
        "### SatVision-B Prediction",
        "### UNet (CNN) Prediction",
        "### UNet (CNN) LS Pretrained Prediction",
        "### MCD12Q1 LandCover Target",
    )
    for title_col, title in zip(st.columns(5), column_titles):
        title_col.markdown(title)

    st.write("")

    # One grid row per example, one column per image kind (5 kinds). Sizing
    # by len(rows) avoids rendering empty trailing rows.
    grid = make_grid(len(rows), 5)

    for grid_row, row in zip(grid, rows):
        # Each entry of `row` is a (path, caption) pair.
        for cell, (image_path, caption) in zip(grid_row, row):
            cell.image(image_path, caption, use_column_width=True)

    st.markdown("### Few-Shot Learning with SatVision-Base")
    description = (
        "Pre-trained vision transformers (we use SwinV2) offer a clear "
        "advantage when applying a model to a task with very little "
        "labeled training data. We pre-trained SatVision-Base on 26 million "
        "MODIS Surface Reflectance image patches. This allows the "
        "SatVision-Base models to learn relevant features and representations "
        "from a diverse range of scenes. This knowledge can be transferred to "
        "a few-shot learning task, enabling the model to leverage its "
        "understanding of spatial patterns, textures, and contextual "
        "information."
    )
    st.markdown(description)
|
|
|
|
|
|
|
|
|
|
|
def load_images(selected_option: str, image_dir: Path):
    """Pair each input image chip for *selected_option* with a display caption.

    Returns a list of ``(path_as_str, caption)`` tuples, numbered from 1 in
    the order ``find_images`` returns the files.
    """
    chip_paths = find_images(selected_option, image_dir)
    return [
        (str(chip), f"MOD09GA 3-2-1 H18v04 2019 Example {number}")
        for number, chip in enumerate(chip_paths, start=1)
    ]
|
|
|
|
|
|
|
|
|
|
|
def find_images(selected_option: int, image_dir: Path):
    """Return the sorted input image chips for one training-set size.

    Args:
        selected_option: Number of training samples; interpolated into the
            filename pattern ``ft_demo_<option>_*_img.png``.
        image_dir: Directory searched (non-recursively) for the chips.

    Returns:
        The matching paths, sorted by path.

    Raises:
        ValueError: If there are not exactly 3 matches, or the first sorted
            match is not the ``1071`` example.
    """
    images_pattern = f"ft_demo_{selected_option}_*_img.png"
    images_matching = sorted(image_dir.glob(images_pattern))

    # Explicit raises instead of `assert`: asserts are stripped under
    # `python -O`, which would let a wrong file set through or turn an empty
    # match into a bare IndexError below.
    if len(images_matching) != 3:
        raise ValueError(
            f"Should be 3 images matching {images_pattern!r} in {image_dir}, "
            f"found {len(images_matching)}"
        )

    # Check only the file name; the directory part may itself contain "1071".
    if "1071" not in images_matching[0].name:
        raise ValueError(
            f"Should be 1071, got first image {images_matching[0].name!r}"
        )

    return images_matching
|
|
|
|
|
|
|
|
|
|
|
def load_labels(selected_option, label_dir: Path):
    """Pair each land-cover target image for *selected_option* with a caption.

    Returns a list of ``(path_as_str, caption)`` tuples, numbered from 1 in
    the order ``find_labels`` returns the files.
    """
    target_paths = find_labels(selected_option, label_dir)
    return [
        (str(target), f"MCD12Q1 LandCover Target Example {number}")
        for number, target in enumerate(target_paths, start=1)
    ]
|
|
|
|
|
|
|
|
|
|
|
def find_labels(selected_option: int, label_dir: Path):
    """Return the sorted land-cover target images for one training-set size.

    Args:
        selected_option: Number of training samples; interpolated into the
            filename pattern ``ft_demo_<option>_*_label.png``.
        label_dir: Directory searched (non-recursively) for the labels.

    Returns:
        The matching paths, sorted by path.

    Raises:
        ValueError: If there are not exactly 3 matches, or the first sorted
            match is not the ``1071`` example.
    """
    labels_pattern = f"ft_demo_{selected_option}_*_label.png"
    labels_matching = sorted(label_dir.glob(labels_pattern))

    # Explicit raises instead of `assert`, which `python -O` strips.
    if len(labels_matching) != 3:
        raise ValueError(
            f"Should be 3 label images matching {labels_pattern!r} in "
            f"{label_dir}, found {len(labels_matching)}"
        )

    # Check only the file name; the directory part may itself contain "1071".
    if "1071" not in labels_matching[0].name:
        raise ValueError(
            f"Should be 1071, got first label {labels_matching[0].name!r}"
        )

    return labels_matching
|
|
|
|
|
|
|
|
|
|
|
def load_predictions(selected_option: str, pred_dir: Path):
    """Collect captioned prediction images for every compared model.

    Returns a dict keyed by ``"svb"``, ``"unet"`` and ``"unet-ls"``; each
    value is a list of ``(path_as_str, caption)`` tuples numbered from 1.
    """
    # (result key, model name passed to find_preds, caption prefix)
    model_table = (
        ("svb", "svb", "SatVision-B Prediction Example"),
        ("unet", "cnn", "Unet Prediction Example"),
        ("unet-ls", "cnn-ls", "Unet LS Pre-trained Prediction Example"),
    )

    prediction_dict = {}
    for result_key, model_name, caption_prefix in model_table:
        model_paths = find_preds(selected_option, pred_dir, model_name)
        prediction_dict[result_key] = [
            (str(pred_path), f"{caption_prefix} {number}")
            for number, pred_path in enumerate(model_paths, start=1)
        ]
    return prediction_dict
|
|
|
|
|
|
|
|
|
|
|
def find_preds(selected_option: int, pred_dir: Path, model: str):
    """Return the sorted prediction images of one model for one training size.

    Files are looked up in ``pred_dir / str(selected_option) / model``.

    Args:
        selected_option: Number of training samples.
        pred_dir: Root directory of all predictions.
        model: ``"cnn"``, ``"cnn-ls"``, or anything else (e.g. ``"svb"``),
            which uses the SatVision-B filename pattern.

    Returns:
        The matching paths, sorted by path.

    Raises:
        FileNotFoundError: If the model-specific directory is missing.
        ValueError: If there are not exactly 3 matches, or the first sorted
            match is not the ``1071`` example.
    """
    cnn_patterns = {
        "cnn": f"ft_cnn_demo_{selected_option}_*cnn-plain_pred.png",
        "cnn-ls": f"ft_cnn_demo_{selected_option}_*cnn-ls_pred.png",
    }
    # Any model without a CNN-specific pattern uses the SatVision-B naming.
    pred_pattern = cnn_patterns.get(
        model, f"ft_demo_{selected_option}_*_pred.png"
    )

    model_specific_dir = pred_dir / str(selected_option) / model

    # Explicit raises instead of `assert`, which `python -O` strips.
    if not model_specific_dir.is_dir():
        raise FileNotFoundError(f"{model_specific_dir} does not exist")

    preds_matching = sorted(model_specific_dir.glob(pred_pattern))

    if len(preds_matching) != 3:
        raise ValueError(
            f"Should be 3 prediction images matching {pred_pattern!r} in "
            f"{model_specific_dir}, found {len(preds_matching)}"
        )

    # Check only the file name; the directory part may itself contain "1071".
    if "1071" not in preds_matching[0].name:
        raise ValueError(
            f"Should be 1071, got first prediction {preds_matching[0].name!r}"
        )

    return preds_matching
|
|
|
|
|
|
|
|
|
|
|
def make_grid(cols, rows):
    """Lay out a grid of Streamlit columns, one ``st.container`` per row.

    Note the parameter names are inverted relative to their effect: ``cols``
    is the number of rows created, and ``rows`` is the number of columns in
    each row (passed to ``st.columns`` with ``gap="large"``).

    Returns:
        A list with one entry per created row; each entry is the value
        returned by ``st.columns``, so cells are addressed ``grid[row][col]``.
    """
    grid = []
    for _ in range(cols):
        with st.container():
            grid.append(st.columns(rows, gap="large"))
    return grid
|
|
|
|
|
|
|
|
|
|
|
# Render the app only when this module is executed as the main script,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|