edemana's picture
Upload folder using huggingface_hub
da0313f verified
import torch
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
from tqdm import tqdm
def optimize_latent_vector(G, target_image, num_iterations=1000, *, lr=0.1, **g_kwargs):
    """Fit a latent vector ``z`` so that ``G(z, None)`` reproduces ``target_image``.

    The target is resized to ``G.img_resolution`` x ``G.img_resolution``,
    converted to a tensor scaled to [-1, 1], and compared to the generator
    output with mean-squared error. Only the latent vector is optimized
    (Adam); the generator's parameters are never stepped.

    Args:
        G: Generator, called as ``G(z, c, **g_kwargs)``; must expose
            ``img_resolution`` and ``z_dim``. NOTE(review): assumed to already
            live on the device chosen below (CUDA if available, else CPU) —
            confirm where G is loaded.
        target_image: PIL image to reconstruct. Any PIL mode is accepted; it
            is converted to RGB first.
        num_iterations: Number of Adam steps.
        lr: Adam learning rate (0.1 was the previously hard-coded value).
        **g_kwargs: Extra keyword arguments forwarded to every ``G`` call
            (e.g. ``truncation_psi`` / ``noise_mode``), so the fit can use the
            same settings later used for sampling. Empty by default, which
            reproduces the original ``G(z, None)`` call.

    Returns:
        The optimized latent of shape ``(1, G.z_dim)``, detached from the graph.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Uploaded PNGs are often RGBA, and some images are grayscale ('L').
    # ToTensor would then yield 4 (or 1) channels, and the MSE against the
    # generator output would fail or silently broadcast.
    # NOTE(review): assumes G produces 3-channel RGB images.
    target_image = target_image.convert('RGB')
    target_image = transforms.Resize((G.img_resolution, G.img_resolution))(target_image)
    target_tensor = transforms.ToTensor()(target_image).unsqueeze(0).to(device)
    target_tensor = target_tensor * 2 - 1  # ToTensor yields [0, 1]; map to [-1, 1]

    latent_vector = torch.randn((1, G.z_dim), device=device, requires_grad=True)
    optimizer = torch.optim.Adam([latent_vector], lr=lr)

    for i in tqdm(range(num_iterations), desc="Optimizing latent vector"):
        optimizer.zero_grad()
        generated_image = G(latent_vector, None, **g_kwargs)
        loss = torch.nn.functional.mse_loss(generated_image, target_tensor)
        # Accumulate gradients only into the latent. A plain backward() would
        # also fill .grad on any generator parameter that requires grad.
        # This optimizer never zeroes those, so they would keep growing
        # across iterations and across calls.
        loss.backward(inputs=[latent_vector])
        optimizer.step()
        if (i + 1) % 100 == 0:
            # tqdm.write keeps the progress bar intact, unlike print().
            tqdm.write(f'Iteration {i+1}/{num_iterations}, Loss: {loss.item()}')

    return latent_vector.detach()
def generate_from_upload(uploaded_image):
    """Gradio callback: fit a latent to the upload and return four variations.

    Uses the module-level generator ``G`` (which must expose ``z_dim`` and
    accept ``truncation_psi`` / ``noise_mode`` keyword arguments) and
    ``optimize_latent_vector`` to invert the image. It then perturbs the fitted
    latent with Gaussian noise and renders one image per perturbation.

    Args:
        uploaded_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A 4-tuple of PIL images, one per ``gr.Image`` output component.

    Raises:
        gr.Error: If no image was uploaded (Gradio shows the message to the user).
    """
    if uploaded_image is None:
        raise gr.Error("Please upload an image first.")

    num_variations = 4  # must match the number of output components in the Interface
    variation_strength = 0.1  # std-dev scale of the Gaussian noise added to the latent

    optimized_z = optimize_latent_vector(G, uploaded_image)

    # NOTE(review): the fit above calls G with its default truncation/noise
    # settings, but sampling below uses truncation_psi=0.7 and
    # noise_mode='const'. In StyleGAN, truncation pulls samples toward the
    # average image, so variations may drift away from the upload. Consider
    # fitting with the same settings.
    noise = torch.randn((num_variations, G.z_dim), device=optimized_z.device)
    # (1, z_dim) + (num_variations, z_dim) broadcasts to one latent per variation.
    varied_z = optimized_z + noise * variation_strength

    with torch.no_grad():
        imgs = G(varied_z, c=None, truncation_psi=0.7, noise_mode='const')

    # The fit targets lie in [-1, 1], so G output presumably does too.
    # Rescale to uint8 and move channels last (NCHW -> NHWC) for PIL.
    imgs = (imgs * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()

    return tuple(Image.fromarray(img) for img in imgs)
# Gradio UI: one uploaded image in, four generated variations out.
# NOTE(review): `G`, used by generate_from_upload, is not defined anywhere in
# the visible source. It must be loaded (e.g. from a network pickle) before
# the interface handles a request — confirm where it comes from.
iface = gr.Interface(
    fn=generate_from_upload,
    inputs=gr.Image(type="pil"),
    # Must stay in sync with the 4-tuple returned by generate_from_upload.
    outputs=[gr.Image(type="pil") for _ in range(4)],
    title="StyleGAN Image Variation Generator",
)

# Launch only when run as a script (e.g. `python app.py`). Importing this
# module (tests, tooling, Gradio reload mode) should not start a server.
if __name__ == "__main__":
    iface.launch(share=True, debug=True)