import os
import zipfile

import gradio as gr
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from diffusers import StableDiffusionPipeline

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define your custom dataset
class CustomImageDataset(Dataset):
    def __init__(self, images, prompts, transform=None):
        self.images = images
        self.prompts = prompts
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        prompt = self.prompts[idx]
        return image, prompt


# Function to fine-tune the model
def fine_tune_model(images, prompts, model_save_path, num_epochs=3):
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # map to [-1, 1], as the VAE expects
    ])
    dataset = CustomImageDataset(images, prompts, transform)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    # Load the Stable Diffusion pipeline and its components
    pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)
    vae = pipeline.vae
    unet = pipeline.unet
    text_encoder = pipeline.text_encoder
    tokenizer = pipeline.tokenizer  # use the pipeline's own tokenizer; SD2 does not use openai/clip-vit-base-patch32
    scheduler = pipeline.scheduler

    # Only the UNet is trained; freeze the VAE and text encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.train()

    optimizer = torch.optim.AdamW(unet.parameters(), lr=5e-6)

    # Fine-tuning loop
    for epoch in range(num_epochs):
        for i, (batch_images, batch_prompts) in enumerate(dataloader):
            batch_images = batch_images.to(device)

            # Tokenize the prompts
            inputs = tokenizer(
                list(batch_prompts),
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).to(device)

            with torch.no_grad():
                # Encode images to latents and prompts to text embeddings (both frozen)
                latents = vae.encode(batch_images).latent_dist.sample() * vae.config.scaling_factor  # 0.18215 for SD
                text_embeddings = text_encoder(inputs.input_ids).last_hidden_state

            # Sample noise and random timesteps, then noise the latents via the scheduler
            # (simply adding noise to the latents would not match the diffusion process)
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, scheduler.config.num_train_timesteps, (latents.size(0),), device=device
            ).long()
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)

            # The training target depends on the scheduler's prediction type
            # (stabilityai/stable-diffusion-2 uses v-prediction)
            if scheduler.config.prediction_type == "v_prediction":
                target = scheduler.get_velocity(latents, noise, timesteps)
            else:
                target = noise

            # Predict and take an optimizer step
            pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
            loss = torch.nn.functional.mse_loss(pred, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Save the fine-tuned pipeline
    pipeline.save_pretrained(model_save_path)


# Function to convert a tensor to a PIL Image
def tensor_to_pil(tensor):
    tensor = tensor.squeeze().cpu().clamp(0, 1)  # drop the batch dimension if present
    return transforms.ToPILImage()(tensor)


# Function to generate images
def generate_images(pipeline, prompt):
    with torch.no_grad():
        output = pipeline(prompt)  # generate from the prompt
    return output.images[0]  # the pipeline already returns PIL images


# Function to zip the fine-tuned model
def zip_model(model_path):
    zip_path = f"{model_path}.zip"
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for root, _, files in os.walk(model_path):
            for file in files:
                full_path = os.path.join(root, file)
                zipf.write(full_path, os.path.relpath(full_path, model_path))
    return zip_path


# Function to save an uploaded file (Gradio passes uploads as temp-file paths)
def save_uploaded_file(uploaded_file, save_path):
    with open(uploaded_file, "rb") as src, open(save_path, "wb") as dst:
        dst.write(src.read())
    return f"File saved at {save_path}"
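
# A minimal sketch of exercising fine_tune_model directly, without the Gradio UI
# below. The image paths and prompts here are hypothetical placeholders; uncomment
# and adjust them to run a quick one-epoch smoke test.
#
# sample_images = [Image.open(p).convert("RGB") for p in ["samples/cat.png", "samples/dog.png"]]
# sample_prompts = ["a photo of a cat", "a photo of a dog"]
# fine_tune_model(sample_images, sample_prompts, "fine_tuned_model", num_epochs=1)
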
# Gradio interface functions
def start_fine_tuning(uploaded_files, prompts, num_epochs):
    # Gradio may hand back file paths or tempfile wrappers depending on version
    images = [Image.open(getattr(f, "name", f)).convert("RGB") for f in uploaded_files]
    # The prompts box is comma-separated; split it into one prompt per image
    prompt_list = [p.strip() for p in prompts.split(",")]
    if len(prompt_list) != len(images):
        return f"Error: {len(images)} images but {len(prompt_list)} prompts were provided."
    model_save_path = "fine_tuned_model"
    fine_tune_model(images, prompt_list, model_save_path, num_epochs=int(num_epochs))
    return "Fine-tuning completed! Model is ready for download."


def download_model():
    model_save_path = "fine_tuned_model"
    if os.path.exists(model_save_path):
        return zip_model(model_save_path)
    return None


def generate_new_image(prompt):
    model_save_path = "fine_tuned_model"
    # Prefer the fine-tuned model if it exists, otherwise fall back to the base model
    if os.path.exists(model_save_path):
        pipeline = StableDiffusionPipeline.from_pretrained(model_save_path).to(device)
    else:
        pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)
    image = generate_images(pipeline, prompt)
    image_path = "generated_image.png"
    image.save(image_path)
    return image_path


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Fine-Tune Stable Diffusion and Generate Images")

    with gr.Tab("Fine-Tune Model"):
        with gr.Row():
            uploaded_files = gr.File(
                label="Upload Images",
                file_types=[".png", ".jpg", ".jpeg"],
                file_count="multiple",
            )
        with gr.Row():
            prompts = gr.Textbox(label="Enter Prompts (comma-separated, one per image)")
            num_epochs = gr.Number(label="Number of Epochs", value=3)
        with gr.Row():
            fine_tune_button = gr.Button("Start Fine-Tuning")
            fine_tune_output = gr.Textbox(label="Output")
        fine_tune_button.click(start_fine_tuning, [uploaded_files, prompts, num_epochs], fine_tune_output)

    with gr.Tab("Download Fine-Tuned Model"):
        download_button = gr.Button("Download Fine-Tuned Model")
        download_output = gr.File()
        download_button.click(download_model, [], download_output)

    with gr.Tab("Generate New Images"):
        prompt_input = gr.Textbox(label="Enter a Prompt")
        generate_button = gr.Button("Generate Image")
        generated_image = gr.Image(label="Generated Image")
        generate_button.click(generate_new_image, [prompt_input], generated_image)

demo.launch()
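
# Usage note: a sketch of how to launch this app, assuming it is saved as app.py
# (the filename is an assumption) with torch, torchvision, diffusers, transformers,
# Pillow, and gradio installed:
#
#     python app.py
#
# Full fine-tuning of the SD2 UNet at 512x512 requires a CUDA GPU with substantial
# VRAM; the training loop will run on CPU but impractically slowly.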