"""Run U2Net saliency inference on one image and save a 2x2 comparison figure."""

import torch
import torch.nn as nn
from torchvision import transforms
from safetensors.torch import load_file
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

from model import U2Net

# Pick the best available backend: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
device = torch.device(device)

# Spatial size every image is resized to, both for the network input and for
# the visualisations, so masks and images line up pixel for pixel.
IMAGE_SIZE = (512, 512)

# Built once at import time instead of on every call; the pipeline is stateless.
# The mean/std values are the ImageNet channel statistics commonly used with
# torchvision models.
_PREPROCESS = transforms.Compose([
    transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image_path):
    """Load an image file and turn it into a normalised model input batch.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A float tensor of shape (1, 3, 512, 512) located on ``device``.
    """
    # The context manager closes the file handle; convert() returns an
    # independent, fully loaded image, so it stays usable afterwards.
    with Image.open(image_path) as source:
        img = source.convert('RGB')
    return _PREPROCESS(img).unsqueeze(0).to(device)


def _predict_saliency(model, image_path):
    """Return the min-max normalised saliency map of an image as a float array."""
    input_img = preprocess_image(image_path)
    with torch.no_grad():
        # Only the first model output is used; the remaining ones are ignored.
        d1, *_ = model(input_img)
        pred = torch.sigmoid(d1)

    pred = pred[0].cpu().numpy()
    # NOTE(review): the model's output layout is not visible from this file.
    # If it carries a singleton channel axis, the array is (1, H, W) here,
    # which matplotlib's imshow rejects; drop that axis so the map is 2-D.
    if pred.ndim == 3 and pred.shape[0] == 1:
        pred = pred[0]

    low = pred.min()
    span = pred.max() - low
    if span > 0:
        return (pred - low) / span
    # A perfectly flat prediction has no contrast to stretch; returning zeros
    # avoids the 0/0 division that would otherwise fill the map with NaN.
    return np.zeros_like(pred)


def _saliency_to_mask(saliency, threshold=None):
    """Convert a [0, 1] saliency map to uint8: binary 0/255 or 0..255 grayscale."""
    if threshold is not None:
        return (saliency > threshold).astype(np.uint8) * 255
    return (saliency * 255).astype(np.uint8)


def run_inference(model, image_path, threshold=None):
    """Predict a saliency mask for the image at ``image_path``.

    Args:
        model: Callable network whose first output is the saliency logits.
        image_path: Path of the image file to segment.
        threshold: If given, pixels whose normalised saliency is strictly
            greater than this value become 255 and all others 0. If ``None``,
            the normalised saliency is scaled to the 0..255 range instead.

    Returns:
        A ``numpy.uint8`` array holding the mask.
    """
    return _saliency_to_mask(_predict_saliency(model, image_path), threshold)


def overlay_segmentation(original_image, binary_mask, alpha=0.5):
    """Blend a mask into the red channel of an image.

    Args:
        original_image: Path of the image file to draw on.
        binary_mask: Mask array matching the 512x512 resized image.
        alpha: Blend weight of the mask layer; the image gets ``1 - alpha``.

    Returns:
        A ``numpy.uint8`` RGB array of the blended image.
    """
    with Image.open(original_image) as source:
        resized = source.convert('RGB').resize(IMAGE_SIZE, Image.BILINEAR)
    original_image_np = np.array(resized)

    # Black layer whose red channel carries the mask, so the segmented region
    # is tinted red in the blend.
    overlay = np.zeros_like(original_image_np)
    overlay[:, :, 0] = binary_mask

    overlay_image = (1 - alpha) * original_image_np + alpha * overlay
    return overlay_image.astype(np.uint8)


def main():
    """Prompt for an image name, segment it, and save the comparison figure."""
    # ---
    model_path = '../testing/u2net-duts-msra.safetensors'
    filename = input('Filename: ')
    image_path = f'../content_images/{filename}'
    # ---

    model = U2Net().to(device)
    # The model is wrapped before the weights are loaded, so the checkpoint
    # keys are presumably those of a DataParallel-wrapped model -- verify
    # against how the checkpoint was saved.
    model = nn.DataParallel(model)
    model.load_state_dict(load_file(model_path, device=device.type))
    model.eval()

    # Run the network once and derive both masks from the same saliency map.
    saliency = _predict_saliency(model, image_path)
    mask = _saliency_to_mask(saliency)
    mask_with_threshold = _saliency_to_mask(saliency, threshold=0.7)

    with Image.open(image_path) as source:
        resized_original = source.resize(IMAGE_SIZE)

    fig = plt.figure(figsize=(10, 10))
    gs = GridSpec(2, 2, figure=fig, wspace=0, hspace=0)

    # Grid layout: left column shows colour images, right column shows masks.
    images = [
        resized_original,
        mask,
        overlay_segmentation(image_path, mask_with_threshold),
        mask_with_threshold,
    ]

    for i, img in enumerate(images):
        ax = fig.add_subplot(gs[i // 2, i % 2])
        # Odd indices are the single-channel masks and need a gray colormap.
        ax.imshow(img, cmap='gray' if i % 2 != 0 else None)
        ax.axis('off')

    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.savefig('../testing/inference-output.jpg', format='jpg', bbox_inches='tight', pad_inches=0)
    plt.close(fig)


if __name__ == '__main__':
    main()