# Web-page residue captured with the code (not part of the program): "Spaces: Sleeping"
from shiny import App, ui, render, reactive | |
import os | |
import numpy as np | |
import torch | |
from PIL import Image | |
from transformers import SamModel, SamProcessor | |
# Pick the compute device up front: CUDA when torch reports it, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The base SAM checkpoint supplies both the preprocessing pipeline and the
# architecture that the fine-tuned weights are loaded into.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")

# Replace the base weights with the fine-tuned state dict; it is deserialized
# onto the CPU first, and the model is moved to `device` afterwards.
model_path = "mito_model_checkpoint.pth"
model.load_state_dict(
    torch.load(model_path, map_location=torch.device("cpu"))
)
model.to(device)
model.eval()  # evaluation mode: this script only runs inference
def process_image(image_path, threshold=0.99):
    """Segment an image with the fine-tuned SAM model and save a binary mask.

    Args:
        image_path: Path of the image file to segment.
        threshold: Cutoff applied to the sigmoid of the predicted mask
            logits; values strictly above it are written as 255, all others
            as 0. Defaults to 0.99.

    Returns:
        Path of the saved mask image: ``image_path`` with its extension
        replaced by ``_segmented.png``.
    """
    # Force RGB regardless of the file's native mode. The context manager
    # closes the underlying file once the pixels have been copied to NumPy.
    with Image.open(image_path) as image:
        image_np = np.array(image.convert("RGB"))

    # Let the SAM processor build the model inputs, then move every tensor
    # to the device the model lives on.
    inputs = processor(images=image_np, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    # Inference only, so no gradients are tracked.
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Logits -> probabilities -> 0/255 uint8 mask for the first batch item.
    pred_masks = torch.sigmoid(outputs.pred_masks).cpu().numpy()
    segmented_image = (pred_masks[0] > threshold).astype(np.uint8) * 255

    # Write the mask next to the input file as an 8-bit grayscale PNG.
    # NOTE(review): the mask is saved at whatever resolution the model
    # emits; presumably it is not resized to the input image - confirm.
    root, _ext = os.path.splitext(image_path)
    output_path = f"{root}_segmented.png"
    # squeeze() drops the size-1 axes so a 2-D array reaches PIL.
    segmented_image_pil = Image.fromarray(segmented_image.squeeze(), mode="L")
    segmented_image_pil.save(output_path)
    return output_path
# Define the Shiny app UI layout: a sidebar with the file-upload control and
# a main panel showing the uploaded image and its segmentation mask.
app_ui = ui.page_fluid(
    ui.layout_sidebar(
        ui.panel_sidebar(
            ui.input_file(
                "image_upload",
                "Upload Satellite Image",
                accept=".jpg,.jpeg,.png,.tif",
            ),
        ),
        ui.panel_main(
            # output_image()'s second positional parameter is `width`, not a
            # caption, so the captions are rendered as separate headings.
            ui.h4("Uploaded Image"),
            ui.output_image("uploaded_image"),
            ui.h4("Segmented Image"),
            ui.output_image("segmented_image"),
        ),
    ),
)
def server(input, output, session):
    """Shiny server function wiring the upload input to the two image outputs.

    Registers ``uploaded_image`` (the file as uploaded) and
    ``segmented_image`` (the mask produced by ``process_image``) as image
    outputs matching the ids used in the UI.

    Args:
        input: Shiny inputs object; ``input.image_upload()`` is read.
        output: Shiny outputs object used to register the render functions.
        session: Shiny session object (unused here).
    """

    def _uploaded_path():
        """Return the datapath of the (first) uploaded file, or None."""
        file_info = input.image_upload()
        if not file_info:
            return None
        # The upload value may be a list of file records or a single record.
        first = file_info[0] if isinstance(file_info, list) else file_info
        return first.get('datapath') or None

    # Without @output/@render.image these would be plain local functions and
    # the UI outputs of the same names would never receive any content.
    @output
    @render.image
    def uploaded_image():
        file_path = _uploaded_path()
        if file_path:
            return {'src': file_path}
        return None

    @output
    @render.image
    def segmented_image():
        try:
            file_path = _uploaded_path()
            if file_path:
                segmented_path = process_image(file_path)
                return {'src': segmented_path}
        except Exception as e:
            # Best-effort boundary: report the failure and render nothing
            # rather than taking the session down.
            print(f"Error processing image: {e}")
        return None
# Create the Shiny app object; `app` is the name ASGI runners and
# `shiny run` look up when they import this module.
app = App(app_ui, server)

# Start the built-in server only when this file is executed directly, so that
# merely importing the module does not launch a second server on port 7860.
if __name__ == "__main__":
    app.run(port=7860)