import torch
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
from tqdm import tqdm

# NOTE(review): ``G`` is not defined in this script; presumably it is a
# StyleGAN2-ADA / StyleGAN3-style generator loaded earlier (e.g. in a previous
# notebook cell) -- confirm. The code below relies only on ``G.img_resolution``,
# ``G.z_dim`` and calls of the form ``G(z, c, truncation_psi=..., noise_mode=...)``.

NUM_VARIATIONS = 4  # images produced per upload; also sizes the Gradio outputs
VARIATION_STRENGTH = 0.1  # scale of the Gaussian jitter added to the fitted latent
TRUNCATION_PSI = 0.7  # used both when fitting the latent and when rendering
NOISE_MODE = 'const'  # used both when fitting the latent and when rendering


def optimize_latent_vector(G, target_image, num_iterations=1000, *, lr=0.1,
                           truncation_psi=None, noise_mode=None, log_every=100):
    """Fit a z-space latent so that ``G`` reproduces ``target_image``.

    Runs Adam on a single random latent, minimising the pixel-wise MSE between
    ``G(z, None)`` and the target image rescaled to ``[-1, 1]``.

    Args:
        G: Generator exposing ``img_resolution``, ``z_dim`` and callable as
            ``G(z, c, **kwargs)``.
        target_image: PIL image to invert. It is converted to RGB and resized
            to ``G.img_resolution`` x ``G.img_resolution`` (aspect ratio is not
            preserved).
        num_iterations: Number of optimisation steps.
        lr: Adam learning rate.
        truncation_psi: Forwarded to ``G`` when not ``None``.
        noise_mode: Forwarded to ``G`` when not ``None``. Pass the same
            ``truncation_psi``/``noise_mode`` you will render with, so the
            fitted latent matches what is actually generated.
        log_every: Report the loss every this many steps; ``0`` disables it.

    Returns:
        The optimised latent, detached, with shape ``(1, G.z_dim)``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Uploads may be RGBA, palette or grayscale; ToTensor would then yield 4 or
    # 1 channels and the MSE against the generator's RGB output would fail.
    target_image = target_image.convert('RGB')
    target_image = transforms.Resize((G.img_resolution, G.img_resolution))(target_image)
    target_tensor = transforms.ToTensor()(target_image).unsqueeze(0).to(device)
    target_tensor = target_tensor * 2 - 1  # [0, 1] -> [-1, 1]

    # Only forward settings the caller asked for, so the default call is
    # exactly ``G(z, None)`` as before.
    gen_kwargs = {}
    if truncation_psi is not None:
        gen_kwargs['truncation_psi'] = truncation_psi
    if noise_mode is not None:
        gen_kwargs['noise_mode'] = noise_mode

    latent_vector = torch.randn((1, G.z_dim), device=device, requires_grad=True)
    optimizer = torch.optim.Adam([latent_vector], lr=lr)

    # Only the latent is optimised. Freeze the generator's weights while we run
    # so backward() does not compute (and endlessly accumulate, since the
    # optimizer never zeroes them) gradients for every generator parameter.
    # The caller's requires_grad flags are restored afterwards.
    params = list(G.parameters()) if isinstance(G, torch.nn.Module) else []
    saved_flags = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad_(False)
    try:
        for i in tqdm(range(num_iterations), desc="Optimizing latent vector"):
            optimizer.zero_grad()
            generated_image = G(latent_vector, None, **gen_kwargs)
            loss = torch.nn.functional.mse_loss(generated_image, target_tensor)
            loss.backward()
            optimizer.step()
            if log_every and (i + 1) % log_every == 0:
                # tqdm.write keeps the progress bar intact, unlike print().
                tqdm.write(f'Iteration {i+1}/{num_iterations}, Loss: {loss.item()}')
    finally:
        for p, flag in zip(params, saved_flags):
            p.requires_grad_(flag)

    return latent_vector.detach()


def generate_from_upload(uploaded_image):
    """Invert ``uploaded_image`` into G's latent space and render variations.

    Args:
        uploaded_image: PIL image supplied by the Gradio ``Image`` input, or
            ``None`` when the user submitted without uploading.

    Returns:
        A tuple of ``NUM_VARIATIONS`` PIL images, one per Gradio output.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if uploaded_image is None:
        raise gr.Error("Please upload an image first.")

    # Fit the latent under the same truncation/noise settings used to render
    # below; otherwise the variations do not resemble the upload.
    optimized_z = optimize_latent_vector(
        G, uploaded_image, truncation_psi=TRUNCATION_PSI, noise_mode=NOISE_MODE
    )

    # Jitter the fitted latent; (1, z_dim) broadcasts against (N, z_dim).
    jitter = torch.randn((NUM_VARIATIONS, G.z_dim), device=optimized_z.device)
    varied_z = optimized_z + jitter * VARIATION_STRENGTH

    with torch.no_grad():
        imgs = G(varied_z, c=None, truncation_psi=TRUNCATION_PSI, noise_mode=NOISE_MODE)

    # Map [-1, 1] to uint8 [0, 255] and reorder NCHW -> NHWC for PIL.
    imgs = (imgs * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()

    return tuple(Image.fromarray(img) for img in imgs)


# Create the Gradio interface
iface = gr.Interface(
    fn=generate_from_upload,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil") for _ in range(NUM_VARIATIONS)],
    title="StyleGAN Image Variation Generator",
)

if __name__ == "__main__":
    # Launch the Gradio interface. share=True publishes a temporary public
    # link; debug=True blocks and surfaces errors in the console/cell output.
    iface.launch(share=True, debug=True)