Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from torchvision import transforms | |
import numpy as np | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
from matplotlib.gridspec import GridSpec | |
from model import U2Net | |
# Select the compute device once at import time: CUDA when available, else CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def preprocess_image(image_path, size=512):
    """Load an image file and turn it into a normalized single-image model batch.

    The image is converted to RGB, resized to ``size`` x ``size``, converted to
    a float tensor and normalized with the mean/std values listed below (the
    standard ImageNet statistics).

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        size: Side length of the square resize. Defaults to 512, the
            resolution the rest of this script assumes.

    Returns:
        A tensor of shape (1, 3, size, size) moved to the module-level ``device``.
    """
    preprocess = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Close the file handle deterministically. convert() loads the pixels and
    # returns a new image, so it stays valid after the source is closed.
    with Image.open(image_path) as src:
        img = src.convert('RGB')
    return preprocess(img).unsqueeze(0).to(device)
def run_inference(model, image_path, threshold=0.5):
    """Run ``model`` on one image and return its prediction as a uint8 map.

    The model's return value is unpacked and only its first element is used.
    That element goes through a sigmoid and is then min-max rescaled to [0, 1]
    over the whole map. Because of this rescaling, ``threshold`` is a fraction
    of the prediction's own value range, not an absolute probability.

    Args:
        model: Callable taking the preprocessed batch; its return value is
            unpacked and only the first element is used.
        image_path: Path passed to ``preprocess_image``.
        threshold: If not None, pixels whose rescaled value is strictly
            greater than this become 255 and all others 0. If None, the
            rescaled values are scaled to 0..255 and returned as a grey map.

    Returns:
        A ``np.uint8`` array with values in 0..255. The leading index is taken
        and a singleton channel axis, if present, is dropped.
    """
    input_img = preprocess_image(image_path)
    with torch.no_grad():
        d1, *_ = model(input_img)
        pred = torch.sigmoid(d1)
    # Take the first item. If a singleton channel axis remains (e.g. a
    # (B, 1, H, W) output), drop it so matplotlib gets a 2-D (H, W) map.
    # NOTE(review): exact output shape of the model is not visible here.
    pred = pred[0]
    if pred.dim() == 3 and pred.shape[0] == 1:
        pred = pred[0]
    pred = pred.cpu().numpy()
    # Min-max rescale. A constant map has zero range, which used to produce
    # NaNs (0 / 0); map it to all zeros instead.
    lo = pred.min()
    value_range = pred.max() - lo
    if value_range > 0:
        pred = (pred - lo) / value_range
    else:
        pred = np.zeros_like(pred)
    if threshold is not None:
        return (pred > threshold).astype(np.uint8) * 255
    return (pred * 255).astype(np.uint8)
def overlay_segmentation(original_image, binary_mask, alpha=0.5, size=512):
    """Blend a mask into the red channel of an image and return the result.

    Args:
        original_image: Path (or file object) of the image to load. It is
            converted to RGB and resized to ``size`` x ``size`` (bilinear).
        binary_mask: Array assignable to a (size, size) slice, such as the
            uint8 0/255 mask produced by ``run_inference``.
        alpha: Weight of the red mask layer; the image is weighted by
            ``1 - alpha``.
        size: Side length of the square resize. Defaults to 512.

    Returns:
        A (size, size, 3) ``np.uint8`` RGB array.
    """
    # Close the source file once the resized RGB pixels have been copied out.
    with Image.open(original_image) as src:
        base = np.array(src.convert('RGB').resize((size, size), Image.BILINEAR))
    # Red-only layer: the mask drives channel 0 and green/blue stay zero.
    # As a result, every pixel is scaled by (1 - alpha), and masked pixels
    # also gain alpha * mask in red.
    overlay = np.zeros_like(base)
    overlay[:, :, 0] = binary_mask
    blended = (1 - alpha) * base + alpha * overlay
    return blended.astype(np.uint8)
if __name__ == '__main__':
    # --- Configuration ---
    model_path = 'results/inter-u2net-duts.pt'
    image_path = 'images/ladies.jpg'
    # ---

    # NOTE(review): the DataParallel wrapper presumably matches how the
    # checkpoint was saved (parameter keys prefixed with 'module.'); confirm.
    model = nn.DataParallel(U2Net().to(device))
    # Load the checkpoint once and reuse it (previously the file was read twice
    # and the first result discarded). weights_only=True restricts unpickling
    # to tensors and plain containers.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # Grey-scale (unthresholded) map and a binary mask at a 0.7 cut-off.
    mask = run_inference(model, image_path, threshold=None)
    mask_with_threshold = run_inference(model, image_path, threshold=0.7)

    # Convert to RGB so grey-scale inputs are not colour-mapped by imshow,
    # and close the file once the resized copy exists.
    with Image.open(image_path) as src:
        original = src.convert('RGB').resize((512, 512))

    # 2x2 grid:  original | soft mask
    #            overlay  | thresholded mask
    fig = plt.figure(figsize=(10, 10))
    gs = GridSpec(2, 2, figure=fig, wspace=0, hspace=0)
    panels = [
        original,
        mask,
        overlay_segmentation(image_path, mask_with_threshold),
        mask_with_threshold,
    ]
    for i, panel in enumerate(panels):
        ax = fig.add_subplot(gs[i // 2, i % 2])
        # Odd indices are the single-channel masks: render them in grey.
        ax.imshow(panel, cmap='gray' if i % 2 != 0 else None)
        ax.axis('off')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.savefig('inference-output.jpg', format='jpg', bbox_inches='tight', pad_inches=0)
    plt.close(fig)