# Provenance: copied from the Hugging Face Hub file viewer.
# Uploaded by jamino30 using huggingface_hub (commit 5464cad, verified), 2.65 kB.
import torch
import torch.nn as nn
from torchvision import transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from model import U2Net
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def preprocess_image(image_path):
    """Load an image file and return it as a normalized, batched tensor.

    The file is opened with PIL, converted to RGB, resized to 512x512,
    turned into a tensor and normalized per channel. A leading batch
    dimension is added before the tensor is moved to the module-level
    ``device``.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The preprocessed tensor with a batch dimension of size 1, on ``device``.
    """
    rgb_image = Image.open(image_path).convert('RGB')

    # Per-channel mean/std; these are the values commonly used with
    # ImageNet-pretrained backbones.
    steps = [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)

    tensor = pipeline(rgb_image)
    batch = tensor.unsqueeze(0)  # add the batch dimension the model is called with
    return batch.to(device)
def run_inference(model, image_path, threshold=0.5):
    """Run ``model`` on a single image and return its mask as a uint8 array.

    The first output of the model is passed through a sigmoid, the first
    batch element is taken, and the result is min-max normalized to [0, 1]
    before being converted to 8-bit.

    Args:
        model: Callable returning an iterable of outputs; only the first
            output is used.
        image_path: Path of the image, forwarded to ``preprocess_image``.
        threshold: Cut-off applied to the normalized prediction. Values
            strictly above it become 255 and all others 0. Pass ``None``
            to get the grayscale map scaled to 0..255 instead.

    Returns:
        A ``numpy.uint8`` array holding the binary or grayscale mask.
    """
    input_img = preprocess_image(image_path)
    with torch.no_grad():
        d1, *_ = model(input_img)
        pred = torch.sigmoid(d1)
    pred = pred[0, :, :].cpu().numpy()

    # Min-max normalize. A constant prediction has zero range; dividing by it
    # would yield NaN (and an undefined NaN -> uint8 cast below), so treat
    # that case as an empty mask. A NaN range also lands in the else branch.
    low = pred.min()
    span = pred.max() - low
    if span > 0:
        pred = (pred - low) / span
    else:
        pred = np.zeros_like(pred)

    if threshold is not None:
        pred = (pred > threshold).astype(np.uint8) * 255
    else:
        pred = (pred * 255).astype(np.uint8)
    return pred
def overlay_segmentation(original_image, binary_mask, alpha=0.5):
    """Blend a mask into the red channel of an image.

    The image is loaded as RGB and resized to 512x512 with bilinear
    resampling. The mask is written into channel 0 of an otherwise black
    layer, which is then alpha-blended with the resized image.

    Args:
        original_image: Path (or file object) accepted by ``PIL.Image.open``.
        binary_mask: Array assignable to a 512x512 channel of the image
            (for example the output of ``run_inference``).
        alpha: Weight of the mask layer; the image gets ``1 - alpha``.

    Returns:
        The blended image as a ``numpy.uint8`` array.
    """
    resized = Image.open(original_image).convert('RGB').resize((512, 512), Image.BILINEAR)
    base = np.array(resized)

    # Black layer whose first (red) channel carries the mask.
    tint = np.zeros_like(base)
    tint[:, :, 0] = binary_mask

    blended = (1 - alpha) * base + alpha * tint
    return blended.astype(np.uint8)
if __name__ == '__main__':
    # ---
    model_path = 'results/inter-u2net-duts.pt'
    image_path = 'images/ladies.jpg'
    # ---

    model = U2Net().to(device)
    # NOTE(review): the model is wrapped in DataParallel before loading,
    # presumably because the checkpoint keys carry the "module." prefix;
    # confirm against how the checkpoint was saved.
    model = nn.DataParallel(model)

    # Read the checkpoint from disk once and reuse it for load_state_dict.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # Grayscale map and a binarized version of the same prediction.
    mask = run_inference(model, image_path, threshold=None)
    mask_with_threshold = run_inference(model, image_path, threshold=0.7)

    # 2x2 grid with no spacing: input, grayscale mask, overlay, binary mask.
    fig = plt.figure(figsize=(10, 10))
    gs = GridSpec(2, 2, figure=fig, wspace=0, hspace=0)

    images = [
        Image.open(image_path).resize((512, 512)),
        mask,
        overlay_segmentation(image_path, mask_with_threshold),
        mask_with_threshold
    ]

    for i, img in enumerate(images):
        ax = fig.add_subplot(gs[i // 2, i % 2])
        # Odd indices hold the single-channel masks, shown in grayscale.
        ax.imshow(img, cmap='gray' if i % 2 != 0 else None)
        ax.axis('off')

    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.savefig('inference-output.jpg', format='jpg', bbox_inches='tight', pad_inches=0)