"""Gradio demo that applies pretrained CycleGAN style-transfer checkpoints to uploaded images."""

from typing import Optional

from flask import Flask  # NOTE(review): unused in this file; kept in case other tooling relies on it
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image

from config import MODEL_CONFIG
from model import CycleGAN

# Checkpoint file for each selectable style, keyed by the label shown in the dropdown.
# NOTE(review): the Van Gogh checkpoints sit at the filesystem root while the others
# live under /checkpoints/checkpoints/ -- confirm these locations are intentional.
model_paths = {
    "CycleGAN_Cezanne_Unet_300": "/checkpoints/checkpoints/cyclegan_cezanne_unet_300_epochs.ckpt",
    "CycleGAN_Monet_Unet_250": "/checkpoints/checkpoints/cyclegan_monet_unet_250_epochs.ckpt",
    "CycleGAN_Vangogh_Resnet_70": "/cyclegan_vangogh_resnet_70_epochs.ckpt",
    "CycleGAN_Vangogh_Unet_70": "/cyclegan_vangogh_unet_70_epochs.ckpt",
}

# Load every checkpoint once at import time. eval() switches the modules to
# inference behaviour (dropout off, normalization layers use stored statistics);
# load_from_checkpoint does not do this for us.
models = {
    name: CycleGAN.load_from_checkpoint(path, **MODEL_CONFIG).eval()
    for name, path in model_paths.items()
}

# Preprocessing: resize to 256x256 and convert to a float tensor with values in [0, 1].
# NOTE(review): there is no Normalize step -- if the generators were trained on
# [-1, 1] inputs, add one here and rescale the output in translate_image to match.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Shared tensor -> PIL converter used for the model output.
_to_pil = transforms.ToPILImage()


def translate_image(input_image: Optional[Image.Image], style: Optional[str]) -> Image.Image:
    """Translate *input_image* with the CycleGAN model selected by *style*.

    Args:
        input_image: Image from the Gradio ``Image`` component, or ``None`` when
            the user submits without uploading one.
        style: Key of ``models`` chosen in the dropdown, or ``None`` if nothing
            has been selected.

    Returns:
        The model output converted to a PIL image.

    Raises:
        gr.Error: If no image was uploaded or *style* is not a loaded model;
            Gradio shows the message to the user instead of a stack trace.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    model = models.get(style)
    if model is None:
        raise gr.Error(f"Unknown style {style!r}; choose one of: {', '.join(models)}")

    # Uploads may be RGBA (PNG with alpha) or single-channel, which would give
    # ToTensor 4 or 1 channels. Force 3 channels.
    # NOTE(review): assumes the generators expect 3-channel RGB input -- confirm.
    rgb_image = input_image.convert("RGB")

    # Build a batch of one on the same device as the model's weights.
    device = next(model.parameters()).device
    batch = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(batch)

    # ToPILImage multiplies float tensors by 255 and casts to uint8. Clamp first
    # so out-of-range values saturate instead of wrapping around. It also needs
    # a CPU tensor.
    return _to_pil(output.squeeze(0).clamp(0.0, 1.0).cpu())


# Gradio UI: an image upload plus a style dropdown, producing the translated image.
iface = gr.Interface(
    fn=translate_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(choices=list(models.keys()), label="Select Style"),
    ],
    outputs=gr.Image(type="pil"),
    title="CycleGAN Image Translation",
    description="Upload an image and select a style to translate it using CycleGAN.",
)

if __name__ == "__main__":
    iface.launch(debug=True)