Spaces:
Runtime error
Runtime error
File size: 2,909 Bytes
ab52a15 8b61b70 ab52a15 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import functools
import os
import tempfile
from pathlib import Path

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
from shiny.types import FileInfo
from transformers import SamModel
# Preprocessing applied to every uploaded image before it reaches the model:
# resize to exactly 1024x1024 (aspect ratio is not preserved), then convert the
# PIL image into a tensor.
_resize_steps = [
    transforms.Resize(size=(1024, 1024)),
    transforms.ToTensor(),
]
image_resize_transform = transforms.Compose(_resize_steps)
# Page layout: an upload control for a single image, followed by two image
# outputs whose ids match the render functions defined inside server(): the
# uploaded picture and the model's output image.
app_ui = ui.page_fluid(
    ui.input_file(
        "file2",
        "Choose Image",
        accept=".jpg, .jpeg, .png, .tiff, .tif",
        multiple=False,
    ),
    *(ui.output_image(output_id) for output_id in ("original_image", "image_display")),
)
@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the fine-tuned SAM model once and reuse it for every upload.

    The ``facebook/sam-vit-base`` architecture is created with
    ``SamModel.from_pretrained``. Its weights are then replaced by the state
    dict stored in ``model.pth``, a path relative to the current working
    directory. Caching means the model is not rebuilt and the weights are not
    re-read on each new upload.

    Returns:
        A ``(model, device)`` tuple. The model is in eval mode and has been
        moved to ``device``, which is CUDA when available and CPU otherwise.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SamModel.from_pretrained("facebook/sam-vit-base")
    model.load_state_dict(torch.load("model.pth", map_location=device))
    model.eval()
    model.to(device)
    return model, device


def _save_png(img):
    """Write *img* to a new temporary ``.png`` file and return its path.

    The file is created with ``delete=False`` so that it still exists after its
    handle is closed and can be served by the image outputs.
    """
    # NOTE(review): these temp files are never removed; consider cleaning them
    # up when the session ends.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        img.save(tmp, "PNG")
    return tmp.name


def server(input: Inputs, output: Outputs, session: Session):
    """Shiny server: segment the uploaded image and show it beside the result.

    Args:
        input: Reactive inputs. ``input.file2`` is the image upload control.
        output: Not referenced directly. The nested ``@render.image``
            functions below supply the outputs.
        session: The current user session (unused).
    """

    @reactive.calc
    def loaded_image():
        """Run the model on the uploaded image and write both result images.

        Returns:
            ``None`` when nothing has been uploaded yet. Otherwise a tuple
            ``(original_path, overlay_path)`` of temporary PNG files: the
            upload resized to 1024x1024, and the same image with the predicted
            mask blended over it at constant ~50% opacity.
        """
        file: list[FileInfo] | None = input.file2()
        if not file:  # covers both "no upload yet" and an empty list
            return None

        model, device = _load_model()

        # Open the upload once; closing the source handle is safe because
        # convert() returns a new, fully loaded image.
        with Image.open(file[0]["datapath"]) as src:
            image = src.convert("RGB")

        pixel_values = image_resize_transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values, multimask_output=False)

        # Keep the first mask of the first point batch.
        # NOTE(review): assumes pred_masks is (batch, point_batch, masks, H, W).
        logits = outputs.pred_masks.squeeze(1)[:, 0, :, :]
        # pred_masks holds unbounded logits. Scaling them by 255 and casting to
        # uint8 directly wraps around, so map them into [0, 1] first.
        probs = torch.sigmoid(logits).squeeze().cpu().numpy()
        mask_array = np.round(probs * 255).astype(np.uint8)

        mask = Image.fromarray(mask_array).resize((1024, 1024), Image.LANCZOS)
        mask = mask.convert("RGBA")
        mask.putalpha(128)  # constant alpha so the photo shows through the mask

        base = image.resize((1024, 1024), Image.LANCZOS).convert("RGBA")
        combined = Image.alpha_composite(base, mask)
        # The second output is the blended overlay, not the bare mask.
        return _save_png(base), _save_png(combined)

    @render.image
    def original_image():
        """Show the uploaded image (resized to 1024x1024) at 300px wide."""
        result = loaded_image()
        if result is None:
            return None
        original_path, _ = result
        return {"src": original_path, "width": "300px"}

    @render.image
    def image_display():
        """Show the uploaded image with the predicted mask overlaid, 300px wide."""
        result = loaded_image()
        if result is None:
            return None
        _, overlay_path = result
        return {"src": overlay_path, "width": "300px"}
# Application object combining the page layout in ``app_ui`` with the reactive
# logic in ``server``.
app = App(ui=app_ui, server=server)
|