Caleb Spradlin
changed pre-trained
0a97993
import streamlit as st
from pathlib import Path
# -----------------------------------------------------------------------------
# main
# -----------------------------------------------------------------------------
def main():
    """Render the SatVision few-shot comparison page.

    The user picks how many training samples the models were trained on.
    For each example chip, the page then shows five images side by side:
    the input image, the SatVision-B, UNet and UNet-LS predictions, and
    the MCD12Q1 land-cover target.
    """
    st.title("SatVision Few-Shot Comparison")
    st.write("")

    selected_option = st.selectbox(
        "Number of training samples", [10, 1000, 5000]
    )
    # The control above is a selectbox (dropdown), not a slider.
    st.markdown(
        "Use the dropdown to select how many training "
        + "samples the models were trained on"
    )

    images = load_images(selected_option, Path("./images/images"))
    labels = load_labels(selected_option, Path("./images/labels"))
    preds = load_predictions(selected_option, Path("./images/predictions"))

    # One row per example. Each row holds five (image path, caption) pairs,
    # in the same order as the column titles below.
    rows = list(
        zip(images, preds["svb"], preds["unet"], preds["unet-ls"], labels)
    )

    column_titles = [
        "### MOD09GA [3-2-1] Image Chip",
        "### SatVision-B Prediction",
        "### UNet (CNN) Prediction",
        "### UNet (CNN) LS Pretrained Prediction",
        "### MCD12Q1 LandCover Target",
    ]

    st.write("")
    for title_column, title in zip(st.columns(len(column_titles)), column_titles):
        title_column.markdown(title)
    st.write("")

    # Size the grid from the data rather than a hard-coded 5x5. This avoids
    # empty trailing rows and an IndexError if more examples are added.
    grid = make_grid(len(rows), len(column_titles))
    for grid_row, row in zip(grid, rows):
        for cell, (image_path, caption) in zip(grid_row, row):
            cell.image(image_path, caption, use_column_width=True)

    st.markdown("### Few-Shot Learning with SatVision-Base")
    description = (
        "Pre-trained vision transformers (we use SwinV2) offer a "
        + "good advantage when looking to apply a model to a task with very little"
        + " labeled training data. We pre-trained SatVision-Base on 26 million"
        + " MODIS Surface Reflectance image patches. This allows the"
        + " SatVision-Base models to learn relevant features and representations"
        + " from a diverse range of scenes. This knowledge can be transferred to a"
        + " few-shot learning task, enabling the model to leverage its"
        + " understanding of spatial patterns, textures, and contextual information."
    )
    st.markdown(description)
# -----------------------------------------------------------------------------
# load_images
# -----------------------------------------------------------------------------
def load_images(selected_option: str, image_dir: Path):
    """
    Pair each input image chip for ``selected_option`` with a display caption.

    Args:
        selected_option: Training-set size used to select the image files.
        image_dir: Directory that ``find_images`` searches.

    Returns:
        A list of ``(str(path), caption)`` tuples. Captions are numbered
        from 1, in the order ``find_images`` returns the paths.
    """
    found_paths = find_images(selected_option, image_dir)
    return [
        (str(chip_path), f"MOD09GA 3-2-1 H18v04 2019 Example {number}")
        for number, chip_path in enumerate(found_paths, start=1)
    ]
# -----------------------------------------------------------------------------
# find_images
# -----------------------------------------------------------------------------
def find_images(selected_option: int, image_dir: Path):
    """
    Locate the input image chips for one training-set size.

    Files are matched with the glob pattern
    ``ft_demo_<selected_option>_*_img.png`` inside ``image_dir``.

    Args:
        selected_option: Number of training samples (e.g. 10, 1000, 5000).
        image_dir: Directory containing the image PNGs.

    Returns:
        The three matching paths, sorted.

    Raises:
        FileNotFoundError: If ``image_dir`` does not exist.
        ValueError: If exactly three files do not match, or if the first
            sorted file name does not contain "1071".
    """
    # Explicit checks instead of ``assert`` so validation still runs under
    # ``python -O``, where assert statements are stripped.
    if not image_dir.is_dir():
        raise FileNotFoundError(f"{image_dir} does not exist")

    images_pattern = f"ft_demo_{selected_option}_*_img.png"
    images_matching = sorted(image_dir.glob(images_pattern))
    if len(images_matching) != 3:
        raise ValueError(
            f"Should be 3 images matching pattern {images_pattern!r} "
            f"in {image_dir}, found {len(images_matching)}"
        )
    # Check only the file name so that "1071" elsewhere in the path
    # (e.g. a parent directory) cannot satisfy the check.
    if "1071" not in images_matching[0].name:
        raise ValueError("Should be 1071")
    return images_matching
# -----------------------------------------------------------------------------
# load_labels
# -----------------------------------------------------------------------------
def load_labels(selected_option, label_dir: Path):
    """
    Pair each MCD12Q1 land-cover target image with a display caption.

    Args:
        selected_option: Training-set size used to select the label files.
        label_dir: Directory that ``find_labels`` searches.

    Returns:
        A list of ``(str(path), caption)`` tuples. Captions are numbered
        from 1, in the order ``find_labels`` returns the paths.
    """
    target_paths = find_labels(selected_option, label_dir)
    return [
        (str(target_path), f"MCD12Q1 LandCover Target Example {number}")
        for number, target_path in enumerate(target_paths, start=1)
    ]
# -----------------------------------------------------------------------------
# find_labels
# -----------------------------------------------------------------------------
def find_labels(selected_option: int, label_dir: Path):
    """
    Locate the land-cover target images for one training-set size.

    Files are matched with the glob pattern
    ``ft_demo_<selected_option>_*_label.png`` inside ``label_dir``.

    Args:
        selected_option: Number of training samples (e.g. 10, 1000, 5000).
        label_dir: Directory containing the label PNGs.

    Returns:
        The three matching paths, sorted.

    Raises:
        FileNotFoundError: If ``label_dir`` does not exist.
        ValueError: If exactly three files do not match, or if the first
            sorted file name does not contain "1071".
    """
    # Explicit checks instead of ``assert`` so validation still runs under
    # ``python -O``, where assert statements are stripped.
    if not label_dir.is_dir():
        raise FileNotFoundError(f"{label_dir} does not exist")

    labels_pattern = f"ft_demo_{selected_option}_*_label.png"
    labels_matching = sorted(label_dir.glob(labels_pattern))
    if len(labels_matching) != 3:
        raise ValueError(
            f"Should be 3 label images matching pattern {labels_pattern!r} "
            f"in {label_dir}, found {len(labels_matching)}"
        )
    # Check only the file name so that "1071" elsewhere in the path
    # (e.g. a parent directory) cannot satisfy the check.
    if "1071" not in labels_matching[0].name:
        raise ValueError("Should be 1071")
    return labels_matching
# -----------------------------------------------------------------------------
# load_predictions
# -----------------------------------------------------------------------------
def load_predictions(selected_option: str, pred_dir: Path):
    """
    Collect captioned prediction images for every model at one training size.

    Args:
        selected_option: Training-set size used to select prediction files.
        pred_dir: Root directory that ``find_preds`` searches.

    Returns:
        A dict with keys ``"svb"``, ``"unet"`` and ``"unet-ls"``. Each value
        is a list of ``(str(path), caption)`` tuples, captions numbered from 1.
    """
    # (result key, model name passed to find_preds, caption prefix)
    model_table = (
        ("svb", "svb", "SatVision-B Prediction Example"),
        ("unet", "cnn", "Unet Prediction Example"),
        ("unet-ls", "cnn-ls", "Unet LS Pre-trained Prediction Example"),
    )
    collected = {}
    for result_key, model_name, caption_prefix in model_table:
        model_paths = find_preds(selected_option, pred_dir, model_name)
        collected[result_key] = [
            (str(model_path), f"{caption_prefix} {number}")
            for number, model_path in enumerate(model_paths, start=1)
        ]
    return collected
# -----------------------------------------------------------------------------
# find_preds
# -----------------------------------------------------------------------------
def find_preds(selected_option: int, pred_dir: Path, model: str):
    """
    Locate one model's prediction images for one training-set size.

    Predictions are searched for in ``pred_dir / str(selected_option) / model``.
    The file-name glob depends on ``model``:

    - ``"cnn"``: ``ft_cnn_demo_<option>_*cnn-plain_pred.png``
    - ``"cnn-ls"``: ``ft_cnn_demo_<option>_*cnn-ls_pred.png``
    - anything else (e.g. ``"svb"``): ``ft_demo_<option>_*_pred.png``

    Args:
        selected_option: Number of training samples (e.g. 10, 1000, 5000).
        pred_dir: Root directory holding the per-option prediction folders.
        model: Model sub-directory name, which also selects the pattern.

    Returns:
        The three matching paths, sorted.

    Raises:
        FileNotFoundError: If the model-specific directory does not exist.
        ValueError: If exactly three files do not match, or if the first
            sorted file name does not contain "1071".
    """
    if model == "cnn":
        pred_pattern = f"ft_cnn_demo_{selected_option}_*cnn-plain_pred.png"
    elif model == "cnn-ls":
        pred_pattern = f"ft_cnn_demo_{selected_option}_*cnn-ls_pred.png"
    else:
        pred_pattern = f"ft_demo_{selected_option}_*_pred.png"

    model_specific_dir = pred_dir / str(selected_option) / model
    # Explicit checks instead of ``assert`` so validation still runs under
    # ``python -O``, where assert statements are stripped.
    if not model_specific_dir.is_dir():
        raise FileNotFoundError(f"{model_specific_dir} does not exist")

    preds_matching = sorted(model_specific_dir.glob(pred_pattern))
    if len(preds_matching) != 3:
        raise ValueError(
            f"Should be 3 prediction images matching pattern {pred_pattern!r} "
            f"in {model_specific_dir}, found {len(preds_matching)}"
        )
    # Check only the file name so that "1071" elsewhere in the path
    # (e.g. a parent directory) cannot satisfy the check.
    if "1071" not in preds_matching[0].name:
        raise ValueError("Should be 1071")
    return preds_matching
# -----------------------------------------------------------------------------
# make_grid
# -----------------------------------------------------------------------------
def make_grid(cols, rows):
    """
    Build a grid of Streamlit column slots.

    Note the parameter names are swapped relative to their meaning:

    Args:
        cols: Number of grid ROWS. Each row is created inside its own
            ``st.container()``.
        rows: Number of COLUMNS per row, created with ``gap="large"``.

    Returns:
        A list with one entry per row. Each entry is the list of column
        objects returned by ``st.columns``, so cells are ``grid[row][col]``.
    """
    layout = []
    for _ in range(cols):
        with st.container():
            layout.append(st.columns(rows, gap="large"))
    return layout
# -----------------------------------------------------------------------------
# Main execution
# -----------------------------------------------------------------------------
# Build the page only when this file is executed as a script, not when it
# is imported (e.g. by tests that exercise the find_*/load_* helpers).
if __name__ == "__main__":
    main()