Spaces:
Build error
Build error
import torch | |
import torchvision.transforms as transforms | |
from PIL import Image | |
import gradio as gr | |
from tqdm import tqdm | |
def optimize_latent_vector(G, target_image, num_iterations=1000, lr=0.1, log_every=100, **g_kwargs):
    """Fit a latent ``z`` so that ``G(z, None)`` reproduces ``target_image``.

    Runs Adam on a single randomly initialised latent vector, minimising the
    pixel-space mean-squared error between the generator output and the
    (normalised) target image.

    Args:
        G: Generator network. Must expose ``img_resolution`` and ``z_dim`` and
            be callable as ``G(z, c, **kwargs)``. Presumably a StyleGAN-style
            generator whose output lies in [-1, 1] and which already lives on
            the selected device -- TODO confirm.
        target_image: PIL image to reconstruct. It is converted to RGB and
            resized to a ``G.img_resolution`` square (aspect ratio is not
            preserved).
        num_iterations: Number of optimisation steps.
        lr: Adam learning rate (default matches the previous hard-coded 0.1).
        log_every: Report the loss every this many steps; ``0`` disables it.
        **g_kwargs: Extra keyword arguments forwarded to every ``G`` call
            (e.g. synthesis settings), so the fit can use the same settings
            as a later generation step. Empty by default, which reproduces
            the plain ``G(z, None)`` call.

    Returns:
        The optimised latent, detached from the graph, shape ``(1, G.z_dim)``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Force 3 channels: an RGBA or grayscale upload would otherwise become a
    # 4- or 1-channel tensor that fails against (or, for 1 channel, silently
    # broadcasts against) the generator's presumably-RGB output in the loss.
    target_image = target_image.convert('RGB')
    target_image = transforms.Resize((G.img_resolution, G.img_resolution))(target_image)
    target_tensor = transforms.ToTensor()(target_image).unsqueeze(0).to(device)
    target_tensor = (target_tensor * 2) - 1  # ToTensor gives [0, 1]; map to [-1, 1]

    latent_vector = torch.randn((1, G.z_dim), device=device, requires_grad=True)
    optimizer = torch.optim.Adam([latent_vector], lr=lr)

    for i in tqdm(range(num_iterations), desc="Optimizing latent vector"):
        optimizer.zero_grad()
        generated_image = G(latent_vector, None, **g_kwargs)
        loss = torch.nn.functional.mse_loss(generated_image, target_tensor)
        # Only the latent is being optimised. A plain backward() would also
        # accumulate .grad into every generator parameter on every step (and
        # nothing ever zeroes those), wasting memory and compute.
        loss.backward(inputs=[latent_vector])
        optimizer.step()
        if log_every and (i + 1) % log_every == 0:
            # tqdm.write keeps the progress bar intact, unlike print().
            tqdm.write(f'Iteration {i+1}/{num_iterations}, Loss: {loss.item()}')

    return latent_vector.detach()
def generate_from_upload(uploaded_image):
    """Gradio callback: produce four generator variations of an uploaded image.

    The image is first inverted to a latent ``z`` with
    ``optimize_latent_vector``. Four variations are then made by adding
    Gaussian noise (scaled by ``variation_strength``) to that latent, and each
    one is decoded with the generator.

    NOTE(review): relies on a module-level generator ``G`` that is not defined
    in the visible part of this file; make sure it is loaded before launch.

    Args:
        uploaded_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A 4-tuple of PIL images, one per output component.

    Raises:
        gr.Error: If no image was uploaded (Gradio shows the message to the user).
    """
    if uploaded_image is None:
        # Without this guard the callback fails inside optimize_latent_vector
        # with an opaque error instead of telling the user what went wrong.
        raise gr.Error("Please upload an image first.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Invert the uploaded image to a latent vector.
    optimized_z = optimize_latent_vector(G, uploaded_image)

    # Perturb the fitted latent: (1, z_dim) broadcasts against (N, z_dim).
    num_variations = 4  # must match the number of output components in the UI
    variation_strength = 0.1
    varied_z = optimized_z + torch.randn((num_variations, G.z_dim), device=device) * variation_strength

    # NOTE(review): the latent was fitted with G's default call settings, but
    # it is decoded here with truncation_psi=0.7 / noise_mode='const'. That
    # presumably pulls the results away from the upload; confirm it is intended.
    with torch.no_grad():
        imgs = G(varied_z, c=None, truncation_psi=0.7, noise_mode='const')

    # Map presumably-[-1, 1] floats to uint8 pixels; NCHW -> NHWC for PIL.
    imgs = (imgs * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()

    # One PIL image per output component.
    return tuple(Image.fromarray(img) for img in imgs)
# Wire the UI: a single PIL image in, four PIL images (the variations) out.
upload_input = gr.Image(type="pil")
variation_outputs = [gr.Image(type="pil") for _ in range(4)]

iface = gr.Interface(
    fn=generate_from_upload,
    inputs=upload_input,
    outputs=variation_outputs,
    title="StyleGAN Image Variation Generator",
)

# share=True asks Gradio for a public link (NOTE(review): reportedly ignored
# with a warning when hosted on Hugging Face Spaces); debug=True blocks the
# main thread so errors are surfaced in the console.
iface.launch(share=True, debug=True)