# NOTE(review): page-scrape residue from the hosting Hugging Face Space
# header ("Spaces: Runtime error") -- kept as a comment so the file parses.
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
from PIL import Image

from model import CycleGAN
# Load and preprocess the input image
def load_image(image_path, device, image_size=(256, 256)):
    """Load an image from disk and prepare it as a CycleGAN input batch.

    Args:
        image_path: Path to the image file.
        device: torch device string/object the tensor is moved to.
        image_size: (height, width) the image is resized to (default 256x256).

    Returns:
        Float tensor of shape (1, 3, H, W), normalized to [-1, 1].
    """
    transform = T.Compose([
        T.Resize(image_size),
        T.ToTensor(),  # HWC uint8 -> CHW float in [0, 1]
        # Map [0, 1] -> [-1, 1]; display_image inverts this exact transform.
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    image = Image.open(image_path).convert("RGB")  # force 3 channels
    return transform(image).unsqueeze(0).to(device)  # add batch dim
# Display the output image
def display_image(tensor_image):
    """Denormalize a model output tensor and show it with matplotlib.

    Args:
        tensor_image: Tensor of shape (1, 3, H, W) with values in [-1, 1]
            (the normalization produced by load_image).
    """
    # detach() first: a generator output created outside torch.no_grad()
    # still requires grad, and converting it for plotting would raise.
    img = tensor_image.detach().squeeze(0).cpu()  # drop batch dimension
    img = (img * 0.5 + 0.5).clamp(0, 1)  # invert Normalize(mean=0.5, std=0.5)
    plt.imshow(img.permute(1, 2, 0))  # CHW -> HWC for matplotlib
    plt.axis("off")
    plt.show()
# --- Script entry: choose a device and restore the trained model. ---
device = "cuda" if torch.cuda.is_available() else "cpu"

# BUG FIX: MODEL_CONFIG was referenced below but never defined in this
# file, raising NameError at import time (the likely "Runtime error").
# TODO: fill in the CycleGAN hyperparameters used during training; an
# empty dict keeps the original call shape when the checkpoint's saved
# hyperparameters are sufficient on their own.
MODEL_CONFIG: dict = {}
model = CycleGAN.load_from_checkpoint("/content/cyclegan_monet_unet_250_epochs.ckpt", **MODEL_CONFIG)
model.eval()  # inference only: freeze dropout / batch-norm statistics