import random
from pathlib import Path

import torch
from accelerate import Accelerator
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer


class CustomDataset(Dataset):
    """Images from a directory, each paired with the same instance prompt."""

    def __init__(self, data_dir, prompt, tokenizer, size=512, center_crop=False):
        self.data_dir = Path(data_dir)
        self.prompt = prompt
        self.tokenizer = tokenizer
        self.size = size
        self.center_crop = center_crop
        self.image_transforms = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        self.images = [
            f for f in self.data_dir.iterdir()
            if f.is_file() and f.suffix != ".txt"
        ]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx])
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = self.image_transforms(image)
        # Tokenize to a fixed-length tensor so the default collate function
        # can stack prompt ids into a (batch, seq_len) tensor.
        prompt_ids = self.tokenizer(
            self.prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]
        return {"image": image, "prompt_ids": prompt_ids}


def set_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def fine_tune_model(instance_data_dir, instance_prompt, model_name, output_dir,
                    seed=1337, resolution=512, train_batch_size=1, max_train_steps=800):
    # Setup
    accelerator = Accelerator()
    set_seed(seed)

    # Stable Diffusion checkpoints store each component in its own subfolder.
    tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")
    noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")

    # Only the UNet is trained; freeze the VAE and text encoder.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    dataset = CustomDataset(instance_data_dir, instance_prompt, tokenizer, resolution)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=train_batch_size, shuffle=True)
    optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-6)

    unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)

    unet.train()
    progress_bar = tqdm(total=max_train_steps)
    global_step = 0
    # Re-iterate the (typically small) dataset until max_train_steps is reached,
    # rather than stopping after a single pass.
    while global_step < max_train_steps:
        for batch in dataloader:
            with torch.no_grad():
                # Encode images to latents and apply the VAE scaling factor
                # (0.18215 for SD v1 checkpoints).
                latents = vae.encode(batch["image"].to(accelerator.device)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (latents.shape[0],), device=latents.device,
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            encoder_hidden_states = text_encoder(batch["prompt_ids"].to(accelerator.device))[0]
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            # Epsilon-prediction objective: the UNet learns to predict the added noise.
            loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            progress_bar.update(1)
            if global_step >= max_train_steps:
                break
    progress_bar.close()

    # Save the full pipeline so each component (and its config) lands in its
    # own subfolder, instead of the components overwriting one another's
    # config files in a single flat directory.
    pipeline = StableDiffusionPipeline.from_pretrained(
        model_name,
        unet=accelerator.unwrap_model(unet),
        text_encoder=text_encoder,
        vae=vae,
        tokenizer=tokenizer,
    )
    pipeline.save_pretrained(output_dir)
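
# --- Usage sketch ---
# A minimal example of driving the trainer above and smoke-testing the result.
# The data directory, prompt, and output path are hypothetical placeholders;
# the checkpoint id assumes an SD v1-style repo with the standard
# tokenizer/text_encoder/vae/unet/scheduler subfolder layout.
if __name__ == "__main__":
    fine_tune_model(
        instance_data_dir="./instance_images",        # hypothetical folder of training photos
        instance_prompt="a photo of sks dog",         # hypothetical instance prompt
        model_name="runwayml/stable-diffusion-v1-5",  # assumed SD v1.5 checkpoint
        output_dir="./fine_tuned_model",              # hypothetical output path
        max_train_steps=800,
    )

    # Reload the saved pipeline and sample one image to verify the export.
    pipe = StableDiffusionPipeline.from_pretrained(
        "./fine_tuned_model",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
    image = pipe("a photo of sks dog on a beach").images[0]
    image.save("sample.png")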