# Author: cssprad1 — "initial commit" (ab687e7), 3.66 kB
import streamlit as st
import numpy as np
import os
import pathlib
from inference import infer, InferenceModel
# -----------------------------------------------------------------------------
# class SatvisionDemoApp
#
# Directory Structure (relative to the current working directory):
#   data/thumbnails/sv-*.png   thumbnails listed in the sidebar
#   data/images/sv-*.npy       arrays with matching names, fed to the model
#
# -----------------------------------------------------------------------------
class SatvisionDemoApp:
    """Streamlit demo that shows an image, its masked version and the model's
    reconstruction.

    Thumbnails (``sv-*.png``) are read from ``data/thumbnails`` and the
    arrays with matching names (``sv-*.npy``) from ``data/images``; both
    paths are relative to the current working directory.
    """

    # -------------------------------------------------------------------------
    # __init__
    # -------------------------------------------------------------------------
    def __init__(self):
        """Discover the thumbnail and image files and build the model."""
        self.thumbnail_dir = pathlib.Path('data/thumbnails')
        self.image_dir = pathlib.Path('data/images')
        self.thumbnail_files = sorted(self.thumbnail_dir.glob('sv-*.png'))
        # Kept as a public attribute; images are actually resolved by name
        # in load_selected_image rather than through this list.
        self.image_files = sorted(self.image_dir.glob('sv-*.npy'))
        self.thumbnail_names = [path.name for path in self.thumbnail_files]
        # NOTE(review): main() builds a new app (and thus a new model) every
        # time the script runs, and Streamlit re-runs the script on each
        # widget interaction. Consider caching the model construction
        # (e.g. st.cache_resource, if the installed Streamlit provides it).
        self.inferenceModel = InferenceModel()

    # -------------------------------------------------------------------------
    # render_sidebar
    # -------------------------------------------------------------------------
    def render_sidebar(self):
        """List every thumbnail in the sidebar, captioned with its file name."""
        st.sidebar.header("Select an Image")
        for thumbnail in self.thumbnail_names:
            thumbnail_path = self.thumbnail_dir / thumbnail
            st.sidebar.image(str(thumbnail_path), use_column_width=True,
                             caption=thumbnail)

    # -------------------------------------------------------------------------
    # render_main_app
    # -------------------------------------------------------------------------
    def render_main_app(self):
        """Render the input, masked input and reconstruction side by side.

        The image is chosen with a sidebar selectbox over the thumbnail file
        names. If no thumbnails were found, a warning is shown and nothing
        else is rendered.
        """
        st.title("Satvision-Base Demo")
        st.header("Image Reconstruction Process")
        selected_name = st.sidebar.selectbox(
            "Select an Image",
            self.thumbnail_names)
        if selected_name is None:
            # selectbox yields None when it has no options; without this guard
            # load_selected_image would fail on None.
            st.warning(f"No thumbnails found in {self.thumbnail_dir}")
            return
        selected_image = self.load_selected_image(selected_name)
        image, masked_input, output = self.inferenceModel.infer(selected_image)
        col1, col2, col3 = st.columns(3, gap="large")
        # Original, masked model input and model output, side by side.
        with col1:
            st.image(image, use_column_width=True, caption="Input")
        with col2:
            st.image(masked_input, use_column_width=True, caption="Input Masked")
        with col3:
            st.image(output, use_column_width=True, caption="Reconstruction")

    # -------------------------------------------------------------------------
    # load_selected_image
    # -------------------------------------------------------------------------
    def load_selected_image(self, image_name):
        """Load the ``.npy`` array that matches a thumbnail file name.

        Args:
            image_name: Thumbnail file name (e.g. ``sv-*.png``). Its final
                suffix is replaced with ``.npy`` and the file is read from
                ``self.image_dir``.

        Returns:
            The loaded array with axis 0 moved to axis 2.
        """
        # with_suffix only touches the final suffix, unlike str.replace,
        # which would also rewrite a ".png" occurring inside the stem.
        array_path = self.image_dir / pathlib.Path(image_name).with_suffix('.npy')
        image = np.load(array_path)
        # Presumably stored channels-first (C, H, W) and moved to (H, W, C)
        # for the model/display — TODO confirm. Needs at least 3 dimensions.
        return np.moveaxis(image, 0, 2)
# -----------------------------------------------------------------------------
# main
# -----------------------------------------------------------------------------
def main():
    """Create the demo app and draw it: main panel first, then the sidebar."""
    demo_app = SatvisionDemoApp()
    # render_main_app also creates the sidebar selectbox, so it is drawn
    # before the thumbnail list that render_sidebar adds.
    demo_app.render_main_app()
    demo_app.render_sidebar()


if __name__ == "__main__":
    main()