Magix / main.py
Singularity666's picture
Update main.py
7f5929e verified
import os
import shutil
import json
import torch
import random
from pathlib import Path
from torch.utils.data import Dataset
from torchvision import transforms
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from accelerate import Accelerator
from tqdm.auto import tqdm
from PIL import Image
class CustomDataset(Dataset):
    """Dataset pairing every instance image in a folder with one fixed prompt.

    Every regular file in ``data_dir`` whose path does not end in ``.txt`` is
    treated as an image. Each item is a dict with:

    * ``"image"``: the image converted to RGB, resized so its shorter side is
      ``size``, cropped to ``size`` x ``size`` (center or random crop), turned
      into a tensor and normalized with mean 0.5 / std 0.5.
    * ``"prompt_ids"``: a 1-D ``torch.long`` tensor of the prompt's token ids,
      padded/truncated to ``tokenizer.model_max_length``.

    Args:
        data_dir: directory containing the training images.
        prompt: text prompt used for every image.
        tokenizer: callable tokenizer returning an object with ``input_ids``
            and exposing ``model_max_length`` (e.g. a ``CLIPTokenizer``).
        size: target side length of the square output images.
        center_crop: use a center crop instead of a random crop.
    """

    def __init__(self, data_dir, prompt, tokenizer, size=512, center_crop=False):
        self.data_dir = Path(data_dir)
        self.prompt = prompt
        self.tokenizer = tokenizer
        self.size = size
        self.center_crop = center_crop
        self.image_transforms = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        # Sorted so the sample order (and therefore seeded shuffling) does not
        # depend on the filesystem's directory iteration order.
        self.images = sorted(
            f for f in self.data_dir.iterdir()
            if f.is_file() and not str(f).endswith(".txt")
        )
        # The prompt is the same for every sample, so tokenize it once. Keeping
        # it as a tensor (rather than a Python list of ints) lets the default
        # DataLoader collate produce a (batch, seq_len) tensor.
        self.prompt_ids = torch.tensor(
            self.tokenizer(
                self.prompt,
                padding="max_length",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids,
            dtype=torch.long,
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        # The context manager closes the file handle Image.open keeps; the
        # conversion and transforms load the pixels while it is still open.
        with Image.open(image_path) as image:
            if image.mode != "RGB":
                image = image.convert("RGB")
            pixel_values = self.image_transforms(image)
        # Clone so callers cannot mutate the shared cached token tensor.
        return {"image": pixel_values, "prompt_ids": self.prompt_ids.clone()}
def _prompt_ids_as_tensor(prompt_ids):
    """Return collated prompt token ids as a ``(batch, seq_len)`` tensor.

    When a dataset yields plain Python lists of ints, ``default_collate``
    produces a list of ``seq_len`` tensors of shape ``(batch,)``; those are
    stacked back along dim 1. An already-collated tensor is returned as is.
    """
    if isinstance(prompt_ids, torch.Tensor):
        return prompt_ids
    return torch.stack(list(prompt_ids), dim=1)


def fine_tune_model(instance_data_dir, instance_prompt, model_name, output_dir,
                    seed=1337, resolution=512, train_batch_size=1, max_train_steps=800):
    """Fine-tune the UNet of a Stable Diffusion checkpoint on a folder of images.

    Every image in ``instance_data_dir`` is paired with ``instance_prompt``.
    The UNet is trained with an MSE loss between its prediction and the noise
    added to the VAE latents, for ``max_train_steps`` optimizer steps,
    re-iterating the dataset as often as needed. The VAE and text encoder stay
    frozen. The resulting pipeline is written to ``output_dir``.

    Args:
        instance_data_dir: directory of training images (``.txt`` files are
            ignored by ``CustomDataset``).
        instance_prompt: text prompt used for every image.
        model_name: hub id or local path of a diffusers-format Stable Diffusion
            checkpoint with ``tokenizer/``, ``text_encoder/``, ``vae/``,
            ``unet/`` and ``scheduler/`` subfolders.
        output_dir: directory the fine-tuned pipeline is saved to.
        seed: seed passed to ``set_seed``.
        resolution: side length images are resized/cropped to.
        train_batch_size: number of images per batch.
        max_train_steps: number of optimizer steps to run.

    Raises:
        ValueError: if ``instance_data_dir`` contains no training images.
    """
    # Setup
    accelerator = Accelerator()
    set_seed(seed)

    # A diffusers checkpoint stores each component in its own subfolder, the
    # same layout the scheduler below is loaded from.
    tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")
    noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")

    # Only the UNet is optimized; freezing the other models avoids storing
    # activations and computing gradients for them.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    dataset = CustomDataset(instance_data_dir, instance_prompt, tokenizer, resolution)
    if len(dataset) == 0:
        raise ValueError(f"No training images found in {str(instance_data_dir)!r}")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=train_batch_size, shuffle=True)
    optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-6)
    unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)

    # Newer diffusers VAE configs carry the latent scale; 0.18215 is the value
    # this script previously hard-coded.
    scaling_factor = getattr(vae.config, "scaling_factor", 0.18215)

    unet.train()
    global_step = 0
    progress_bar = tqdm(total=max_train_steps, disable=not accelerator.is_local_main_process)
    while global_step < max_train_steps:
        # A handful of instance images gives far fewer batches per pass than
        # max_train_steps, so keep cycling over the dataloader.
        for batch in dataloader:
            with torch.no_grad():
                pixel_values = batch["image"].to(accelerator.device)
                latents = vae.encode(pixel_values).latent_dist.sample() * scaling_factor
                prompt_ids = _prompt_ids_as_tensor(batch["prompt_ids"]).to(accelerator.device)
                encoder_hidden_states = text_encoder(prompt_ids)[0]

            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (latents.shape[0],),
                device=latents.device,
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

            global_step += 1
            progress_bar.update(1)
            if global_step >= max_train_steps:
                break
    progress_bar.close()

    # Save model
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        # unet, vae and text_encoder each write a config.json (and unet/vae the
        # same weight filename), so saving them into one directory overwrote the
        # fine-tuned UNet. Saving the assembled pipeline puts every component in
        # its own subfolder alongside model_index.json.
        pipeline = StableDiffusionPipeline.from_pretrained(
            model_name,
            unet=accelerator.unwrap_model(unet),
            text_encoder=text_encoder,
            vae=vae,
            tokenizer=tokenizer,
        )
        pipeline.save_pretrained(output_dir)
def set_seed(seed):
    """Seed PyTorch and Python's ``random`` module for reproducible runs.

    CUDA generators on every visible device are seeded too, but only when
    CUDA is available.
    """
    torch.manual_seed(seed)
    random.seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)