StyleTransferDemo / utils.py
batmangiaicuuthegioi's picture
Upload 7 files
dca2470 verified
import torch
from PIL import Image
import torchvision.transforms as T
import matplotlib.pyplot as plt
from model import CycleGAN
# Load and preprocess the input image
def load_image(image_path, device, image_size=(256, 256)):
    """Load an image file and turn it into a normalized, batched tensor.

    The image is decoded as RGB, resized to ``image_size``, converted to a
    tensor and normalized per channel with mean 0.5 / std 0.5.

    Args:
        image_path: Location of the image, passed to ``PIL.Image.open``.
        device: Target device passed to ``Tensor.to`` (e.g. ``"cuda"``).
        image_size: Size forwarded to ``torchvision.transforms.Resize``.

    Returns:
        The transformed image with a leading batch dimension added by
        ``unsqueeze(0)``, moved to ``device``.
    """
    transform = T.Compose([
        T.Resize(image_size),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalize to [-1, 1]
    ])
    # Image.open is lazy and may keep the underlying file open; the context
    # manager releases it deterministically. convert() returns a new image,
    # so the pixel data stays usable after the source is closed.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    return transform(image).unsqueeze(0).to(device)
# Display the output image
def display_image(tensor_image):
    """Show an image tensor in a matplotlib window with the axes hidden.

    The tensor is moved to the CPU, its leading dimension is squeezed away
    if it has size 1, the mean 0.5 / std 0.5 normalization is undone, and
    the result is clamped to [0, 1] before plotting.

    Args:
        tensor_image: Image tensor in channel-first (CHW) layout, optionally
            with a leading batch dimension of size 1.
    """
    # detach() drops autograd tracking; without it a tensor that requires
    # grad cannot be converted to a NumPy array for plotting.
    image = tensor_image.detach().squeeze(0).cpu()  # Remove batch dimension
    image = (image * 0.5 + 0.5).clamp(0, 1)  # Denormalize
    plt.imshow(image.permute(1, 2, 0).numpy())  # CHW to HWC
    plt.axis("off")
    plt.show()
# Pick the compute device and load the CycleGAN model from its checkpoint.
# NOTE(review): MODEL_CONFIG is not defined or imported in the visible part of
# this file -- confirm where it is supposed to come from.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = CycleGAN.load_from_checkpoint(
    "/content/cyclegan_monet_unet_250_epochs.ckpt",
    **MODEL_CONFIG,
)