Spaces:
Sleeping
Sleeping
File size: 2,959 Bytes
90baaac f157bf0 90baaac 3dd227f 90baaac 3dd227f 90baaac 3dd227f 90baaac 3dd227f 90baaac 3dd227f 90baaac f157bf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import logging
import os

import numpy as np
import torch
from PIL import Image
from shiny import App, ui, render, reactive
from transformers import SamModel, SamProcessor
# Processor and base architecture come from the public SAM ViT-B release;
# the fine-tuned weights are then loaded over the base model.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model_path = "mito_model_checkpoint.pth"

# map_location puts every tensor on the CPU while loading, so the checkpoint
# can be read on a machine without a GPU.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model = SamModel.from_pretrained("facebook/sam-vit-base")
model.load_state_dict(
    torch.load(model_path, map_location=torch.device('cpu'))
)

# Run on the GPU when one is available, and switch the module to
# evaluation mode for inference (nn.Module.to/eval both return the module).
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device).eval()
def process_image(image_path, threshold=0.99):
    """Segment an image with the fine-tuned SAM model and save the mask as a PNG.

    Args:
        image_path: Path to the input image (any format PIL can open).
        threshold: Cut-off applied to the sigmoid of the mask logits; pixels
            strictly above it become foreground. Defaults to 0.99, the value
            that was previously hard-coded.

    Returns:
        Path of the written mask, ``<input path without extension>_segmented.png``,
        placed next to the input. The mask is single-channel 8-bit
        (0 = background, 255 = foreground) at the input image's resolution.
    """
    # Convert to RGB so grayscale / RGBA / palette files reach the processor in
    # one consistent 3-channel layout; the context manager closes the file.
    with Image.open(image_path) as img:
        image_np = np.array(img.convert("RGB"))

    # The processor resizes/pads the image for the model and also reports the
    # original and resized sizes, which are needed to map the mask back.
    inputs = processor(images=image_np, return_tensors="pt")
    original_sizes = inputs["original_sizes"]
    reshaped_input_sizes = inputs["reshaped_input_sizes"]
    model_inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**model_inputs, multimask_output=False)

    # pred_masks are low-resolution logits for the padded model input, not the
    # original image. post_process_masks upsamples them, crops off the padding
    # and resizes to the original size so the mask lines up with the input.
    # binarize=False keeps raw logits so the sigmoid threshold below applies.
    upscaled = processor.post_process_masks(
        outputs.pred_masks.cpu(),
        original_sizes,
        reshaped_input_sizes,
        binarize=False,
    )
    # upscaled[0] is the first (only) image; [0, 0] picks its single
    # prompt/mask pair, leaving an (H, W) map.
    probabilities = torch.sigmoid(upscaled[0][0, 0]).numpy()
    mask = (probabilities > threshold).astype(np.uint8) * 255

    root, _ext = os.path.splitext(image_path)
    output_path = f"{root}_segmented.png"
    # A 2-D uint8 array is saved as an 8-bit grayscale ("L") PNG.
    Image.fromarray(mask).save(output_path)
    return output_path
# Define the Shiny app UI layout: an upload control in the sidebar, and the
# uploaded image followed by its segmentation mask in the main panel.
app_ui = ui.page_fluid(
    ui.layout_sidebar(
        ui.panel_sidebar(
            # NOTE(review): the label says "Satellite", but the checkpoint file
            # name ("mito_...") suggests mitochondria imagery — confirm wording.
            ui.input_file(
                "image_upload",
                "Upload Satellite Image",
                accept=".jpg,.jpeg,.png,.tif",
            ),
        ),
        ui.panel_main(
            # output_image's second positional parameter is the CSS width, not
            # a caption, so each title is rendered as its own heading.
            ui.h4("Uploaded Image"),
            ui.output_image("uploaded_image"),
            ui.h4("Segmented Image"),
            ui.output_image("segmented_image"),
        ),
    ),
)
def _first_datapath(file_info):
    """Return the server-side path of the first uploaded file, or ``None``.

    ``file_info`` is the value of the ``image_upload`` input: falsy until a
    file is uploaded, then either a list of file-info dicts or a single dict
    (both shapes were handled by the original handlers).
    """
    if not file_info:
        return None
    first = file_info[0] if isinstance(file_info, list) else file_info
    return first.get("datapath")


def server(input, output, session):
    """Connect the uploaded image and its segmentation mask to the UI outputs."""

    @output
    @render.image
    def uploaded_image():
        # Echo the uploaded file back to the browser unchanged.
        file_path = _first_datapath(input.image_upload())
        return {'src': file_path} if file_path else None

    @output
    @render.image
    def segmented_image():
        file_path = _first_datapath(input.image_upload())
        if not file_path:
            return None
        try:
            segmented_path = process_image(file_path)
        except Exception:
            # Best-effort: leave this output empty so the app keeps running,
            # but record the full traceback rather than only the message.
            logging.getLogger(__name__).exception("Error processing image")
            return None
        return {'src': segmented_path}
# Create the Shiny app object. Tools that import this module (for example
# ``shiny run``) look up the module-level ``app``.
app = App(app_ui, server)

# Start the server only when this file is executed directly. Without the
# guard, importing the module would call app.run(), which blocks.
if __name__ == "__main__":
    # NOTE(review): app.run() uses Shiny's default host unless one is given; a
    # container deployment on port 7860 may need host="0.0.0.0" — confirm
    # against how this app is actually launched.
    app.run(port=7860)