File size: 4,046 Bytes
b212cf7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import torch
import torch.nn as nn
from torchvision.transforms import ToTensor, ToPILImage
from PIL import Image
class SobelOperator(nn.Module):
SOBEL_KERNEL_X = torch.tensor(
[[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]
)
SOBEL_KERNEL_Y = torch.tensor(
[[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]]
)
def __init__(self, device="cuda"):
super(SobelOperator, self).__init__()
self.device = device
self.edge_conv_x = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(
self.device
)
self.edge_conv_y = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(
self.device
)
self.edge_conv_x.weight = nn.Parameter(
self.SOBEL_KERNEL_X.view((1, 1, 3, 3)).to(self.device)
)
self.edge_conv_y.weight = nn.Parameter(
self.SOBEL_KERNEL_Y.view((1, 1, 3, 3)).to(self.device)
)
@torch.no_grad()
def forward(
self,
image: Image.Image,
low_threshold: float,
high_threshold: float,
output_type="pil",
) -> Image.Image | torch.Tensor | tuple[Image.Image, torch.Tensor]:
# Convert PIL image to PyTorch tensor
image_gray = image.convert("L")
image_tensor = ToTensor()(image_gray).unsqueeze(0).to(self.device)
# Compute gradients
edge_x = self.edge_conv_x(image_tensor)
edge_y = self.edge_conv_y(image_tensor)
edge = torch.sqrt(torch.square(edge_x) + torch.square(edge_y))
# Apply thresholding
edge.div_(edge.max()) # Normalize to 0-1 (in-place operation)
edge[edge >= high_threshold] = 1.0
edge[edge <= low_threshold] = 0.0
# Convert the result back to a PIL image
if output_type == "pil":
return ToPILImage()(edge.squeeze(0).cpu())
elif output_type == "tensor":
return edge
elif output_type == "pil,tensor":
return ToPILImage()(edge.squeeze(0).cpu()), edge
class ScharrOperator(nn.Module):
SCHARR_KERNEL_X = torch.tensor(
[[-3.0, 0.0, 3.0], [-10.0, 0.0, 10.0], [-3.0, 0.0, 3.0]]
)
SCHARR_KERNEL_Y = torch.tensor(
[[-3.0, -10.0, -3.0], [0.0, 0.0, 0.0], [3.0, 10.0, 3.0]]
)
def __init__(self, device="cuda"):
super(ScharrOperator, self).__init__()
self.device = device
self.edge_conv_x = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(
self.device
)
self.edge_conv_y = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(
self.device
)
self.edge_conv_x.weight = nn.Parameter(
self.SCHARR_KERNEL_X.view((1, 1, 3, 3)).to(self.device)
)
self.edge_conv_y.weight = nn.Parameter(
self.SCHARR_KERNEL_Y.view((1, 1, 3, 3)).to(self.device)
)
@torch.no_grad()
def forward(
self,
image: Image.Image,
low_threshold: float,
high_threshold: float,
output_type="pil",
invert: bool = False,
) -> Image.Image | torch.Tensor | tuple[Image.Image, torch.Tensor]:
# Convert PIL image to PyTorch tensor
image_gray = image.convert("L")
image_tensor = ToTensor()(image_gray).unsqueeze(0).to(self.device)
# Compute gradients
edge_x = self.edge_conv_x(image_tensor)
edge_y = self.edge_conv_y(image_tensor)
edge = torch.abs(edge_x) + torch.abs(edge_y)
# Apply thresholding
edge.div_(edge.max()) # Normalize to 0-1 (in-place operation)
edge[edge >= high_threshold] = 1.0
edge[edge <= low_threshold] = 0.0
if invert:
edge = 1 - edge
# Convert the result back to a PIL image
if output_type == "pil":
return ToPILImage()(edge.squeeze(0).cpu())
elif output_type == "tensor":
return edge
elif output_type == "pil,tensor":
return ToPILImage()(edge.squeeze(0).cpu()), edge
|