import warnings
warnings.filterwarnings("ignore")

import argparse
import json
import logging
from pathlib import Path
from typing import Dict, List
import time
import matplotlib.pyplot as plt
import numpy as np
import datetime
import math

import torch
from torchvision import transforms
from PIL import Image
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL, DDIMScheduler, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import get_scheduler
from tqdm.auto import tqdm
import torch.nn.functional as F
from safetensors.torch import save_file

# Configure logging to be less verbose
logging.basicConfig(
    level=logging.INFO,
    format='%(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Disable specific warnings
import transformers
transformers.logging.set_verbosity_error()
import diffusers
diffusers.logging.set_verbosity_error()


class TrainingMonitor:
    def __init__(self, output_dir: str):
        self.output_dir = Path(output_dir)
        self.metrics = {
            'loss': [],
            'learning_rate': [],
            'epoch': [],
            'step': []
        }
        self.start_time = time.time()

    def log_step(self, loss: float, lr: float, epoch: int, step: int):
        self.metrics['loss'].append(loss)
        self.metrics['learning_rate'].append(lr)
        self.metrics['epoch'].append(epoch)
        self.metrics['step'].append(step)

    def save_metrics(self):
        metrics_file = self.output_dir / 'training_metrics.json'
        with open(metrics_file, 'w') as f:
            json.dump(self.metrics, f)
        self._plot_curves()

    def _plot_curves(self):
        plt.figure(figsize=(12, 8))

        plt.subplot(2, 1, 1)
        plt.plot(self.metrics['step'], self.metrics['loss'])
        plt.title('Training Loss')
        plt.xlabel('Step')
        plt.ylabel('Loss')

        plt.subplot(2, 1, 2)
        plt.plot(self.metrics['step'], self.metrics['learning_rate'])
        plt.title('Learning Rate')
        plt.xlabel('Step')
        plt.ylabel('Learning Rate')

        plt.tight_layout()
        plt.savefig(self.output_dir / 'training_curves.png')
        plt.close()


class KanjiDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_dir: str,
        tokenizer: CLIPTokenizer,
        size: int = 256,
        center_crop: bool = True
    ):
        self.data_dir = Path(data_dir)
        self.tokenizer = tokenizer

        # Load dataset metadata
        with open(self.data_dir / "metadata/dataset.json", "r", encoding="utf-8") as f:
            self.metadata = json.load(f)

        self.transform = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size),  # Always center crop for consistency
            transforms.Grayscale(num_output_channels=3),  # Convert to grayscale but keep 3 channels
            transforms.ToTensor(),
            transforms.Lambda(lambda x: torch.where(x > 0.5, 1.0, 0.0)),  # Threshold to pure black/white
            transforms.Normalize([0.5], [0.5])
        ])

        logger.info(f"Loaded dataset with {len(self.metadata)} examples")

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        item = self.metadata[idx]
        image_path = self.data_dir / item["image_path"]

        # Load and transform image
        image = Image.open(image_path).convert("RGB")
        image = self.transform(image)

        # Combine all meanings into a single prompt
        prompt = ", ".join(item["meanings"])

        # Tokenize text
        encoding = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt"
        )

        return {
            "pixel_values": image.to(torch.float32),  # Keep in float32
            "input_ids": encoding.input_ids[0]
        }
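
# Hedged sanity-check sketch (an illustrative addition, never called by the
# training script; the function name is hypothetical): inspect what
# KanjiDataset actually yields before launching a run.
def _inspect_dataset_sample(data_dir: str, pretrained_model_name: str = "CompVis/stable-diffusion-v1-4"):
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name, subfolder="tokenizer")
    dataset = KanjiDataset(data_dir, tokenizer, size=256)
    sample = dataset[0]
    # After the threshold + Normalize([0.5], [0.5]) steps, pixel values should
    # be exactly -1.0 (black) or +1.0 (white).
    print(sample["pixel_values"].shape, sample["pixel_values"].unique())
    # input_ids is a fixed-length vector of CLIP token ids (model_max_length, usually 77)
    print(sample["input_ids"].shape)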
print(f"") def train_model( data_dir: str, output_dir: str, pretrained_model_name: str = "CompVis/stable-diffusion-v1-4", train_batch_size: int = 4, gradient_accumulation_steps: int = 2, learning_rate: float = 2e-4, vae_learning_rate: float = 2e-5, max_train_steps: int = 6000, mixed_precision: str = "fp16", enable_gradient_checkpointing: bool = True, enable_mem_efficient_attention: bool = True, image_size: int = 256, project_name: str = "kanji-sd", run_name: str = None, enable_wandb: bool = False, warmup_steps: int = 100, max_grad_norm: float = 1.0, freeze_text_encoder: bool = True, freeze_vae: bool = False, ): output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) # Initialize accelerator with memory efficient options ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=gradient_accumulation_steps, mixed_precision=mixed_precision, kwargs_handlers=[ddp_kwargs], log_with="wandb" if enable_wandb else None ) # Initialize wandb if enabled if enable_wandb and accelerator.is_main_process: try: import wandb wandb.init( project=project_name, name=run_name or f"sd-train-{datetime.datetime.now().strftime('%Y%m%d-%H%M%S')}", config={ "pretrained_model": pretrained_model_name, "batch_size": train_batch_size, "grad_accum": gradient_accumulation_steps, "learning_rate": learning_rate, "max_steps": max_train_steps, "mixed_precision": mixed_precision, "image_size": image_size } ) except ImportError: logger.warning("wandb not installed. Skipping wandb logging.") enable_wandb = False # Initialize training monitor monitor = TrainingMonitor(output_dir) if accelerator.is_main_process: logger.info(f"Using mixed precision: {mixed_precision}") logger.info(f"Device: {accelerator.device}") # Memory cleanup helper def cleanup(): if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() if accelerator.is_main_process: logger.info("Loading models and tokenizer...") # Get device device = accelerator.device # Set model dtype based on mixed precision model_dtype = torch.float16 if mixed_precision == "fp16" else torch.float32 # Load tokenizer tokenizer = CLIPTokenizer.from_pretrained( pretrained_model_name, subfolder="tokenizer", use_safetensors=True ) # Load text encoder text_encoder = CLIPTextModel.from_pretrained( pretrained_model_name, subfolder="text_encoder", use_safetensors=True, torch_dtype=model_dtype, low_cpu_mem_usage=True ).to(device) cleanup() # Clean after loading text encoder # Load VAE vae = AutoencoderKL.from_pretrained( pretrained_model_name, subfolder="vae", use_safetensors=True, torch_dtype=model_dtype, low_cpu_mem_usage=True ).to(device) cleanup() # Clean after loading VAE # Load UNet with memory optimizations unet = UNet2DConditionModel.from_pretrained( pretrained_model_name, subfolder="unet", cache_dir=None, use_memory_efficient_attention=enable_mem_efficient_attention, ) if enable_gradient_checkpointing: unet.enable_gradient_checkpointing() unet.to(device) # Enable memory efficient attention if hasattr(unet.config, "use_memory_efficient_attention"): unet.config.use_memory_efficient_attention = True if hasattr(unet.config, "use_sdp_attention"): unet.config.use_sdp_attention = True if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): logger.info("Using flash attention") torch.backends.cuda.enable_flash_sdp(True) # Freeze VAE and text encoder text_encoder.requires_grad_(False) if freeze_vae: vae.requires_grad_(False) # Enable training for ALL UNet parameters to aggressively 
    # Enable training for ALL UNet parameters to aggressively learn kanji style
    logger.info("Enabling training for all UNet parameters to override SD knowledge")
    for param in unet.parameters():
        param.requires_grad_(True)

    # Count trainable parameters
    trainable_params = sum(p.numel() for p in unet.parameters() if p.requires_grad)
    logger.info(f"Aggressively fine-tuning UNet with {trainable_params:,} parameters")

    # Create dataset and dataloader
    dataset = KanjiDataset(data_dir, tokenizer, size=image_size)
    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=train_batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=0
    )

    # Prepare optimizers - separate for UNet and VAE
    try:
        import bitsandbytes as bnb
        optimizer_cls = bnb.optim.AdamW8bit
        if accelerator.is_main_process:
            logger.info("Using 8-bit Adam optimizer for memory efficiency")
    except ImportError:
        optimizer_cls = torch.optim.AdamW
        if accelerator.is_main_process:
            logger.warning("bitsandbytes not found, using regular AdamW. This may use more memory.")

    # Initialize UNet optimizer with a conservative learning rate
    # (note: this hard-coded value deliberately overrides the learning_rate argument)
    unet_optimizer = optimizer_cls(
        unet.parameters(),
        lr=1e-4,  # Conservative base learning rate
        betas=(0.9, 0.999),
        weight_decay=2e-2,  # Increased weight decay for stability
        eps=1e-8
    )

    # Initialize VAE optimizer (also hard-coded, overriding vae_learning_rate)
    vae_optimizer = optimizer_cls(
        vae.parameters(),
        lr=1e-5,  # Maintain 1:10 ratio with the UNet learning rate
        betas=(0.9, 0.999),
        weight_decay=2e-2,
        eps=1e-8
    )

    # Convert steps to float to avoid integer division
    warmup_steps_f = float(warmup_steps)
    max_steps_f = float(max_train_steps)

    if accelerator.is_main_process:
        logger.info(f"Initial UNet LR: {unet_optimizer.param_groups[0]['lr']}")
        logger.info(f"Initial VAE LR: {vae_optimizer.param_groups[0]['lr']}")
        logger.info(f"Warmup steps: {warmup_steps_f}")
        logger.info(f"Max steps: {max_steps_f}")

    # Define conservative learning rate schedule
    def create_lr_lambda():
        def get_lr(step):
            step_f = float(step)
            if step_f < warmup_steps_f:
                # Very gentle warmup from 50% to 100%
                return 0.5 + 0.5 * (step_f / warmup_steps_f)
            else:
                # Gradual cosine decay from 100% to 50%
                progress = (step_f - warmup_steps_f) / (max_steps_f - warmup_steps_f)
                return 0.5 + 0.5 * (0.5 * (1.0 + math.cos(math.pi * progress)))
        return get_lr

    # Define VAE learning rate schedule (same shape as the UNet schedule)
    def create_vae_lr_lambda():
        def get_vae_lr(step):
            step_f = float(step)
            if step_f < warmup_steps_f:
                # Match UNet warmup pattern
                return 0.5 + 0.5 * (step_f / warmup_steps_f)
            else:
                # Match UNet decay pattern
                progress = (step_f - warmup_steps_f) / (max_steps_f - warmup_steps_f)
                return 0.5 + 0.5 * (0.5 * (1.0 + math.cos(math.pi * progress)))
        return get_vae_lr

    # Tighten gradient clipping for stability
    max_grad_norm = 0.5  # Was 1.0; overrides the max_grad_norm argument

    # Create learning rate schedulers
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(unet_optimizer, lr_lambda=create_lr_lambda())
    vae_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(vae_optimizer, lr_lambda=create_vae_lr_lambda())

    # Prepare for training
    unet, vae, unet_optimizer, vae_optimizer, train_dataloader, lr_scheduler, vae_lr_scheduler = accelerator.prepare(
        unet, vae, unet_optimizer, vae_optimizer, train_dataloader, lr_scheduler, vae_lr_scheduler
    )

    # Calculate total training steps (one epoch)
    num_examples = len(train_dataloader.dataset)
    effective_batch_size = train_batch_size * gradient_accumulation_steps
    total_train_steps = num_examples // effective_batch_size
    if num_examples % effective_batch_size != 0:
        total_train_steps += 1  # Add one step for partial batch

    if accelerator.is_main_process:
        logger.info(f"Dataset size: {num_examples} examples")
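        # Worked example (hypothetical numbers, added for illustration): with
        # 6,000 examples, train_batch_size=4 and gradient_accumulation_steps=2,
        # the effective batch size is 4 * 2 = 8, so one epoch is
        # ceil(6000 / 8) = 750 steps. The LR lambdas above then scale the base
        # LR from 0.50x at step 0, to 1.00x at the end of warmup, down to 0.50x
        # at max_train_steps (0.75x at the decay midpoint).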
{effective_batch_size}") logger.info(f"Steps per epoch: {total_train_steps}") # Define checkpoint steps at fixed intervals checkpoint_steps = [1000, 3000, 5000] if accelerator.is_main_process: logger.info(f"Will save checkpoints at steps: {checkpoint_steps}") # Create noise scheduler optimized for binary images noise_scheduler = DDIMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=True, steps_offset=1, ) # Add gradient clipping max_grad_norm = max_grad_norm # Training loop global_step = 0 progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process) batch_step = 0 # Track the actual batch step last_saved_checkpoint = -1 # Track last saved checkpoint # Training loop for batch in train_dataloader: print_gpu_memory() print("Before processing batch") # Get input images and convert to latents pixel_values = batch["pixel_values"].to(device) print_gpu_memory() print("After loading images") # Convert images to pure black and white during training with torch.no_grad(): pixel_values = (pixel_values > 0.5).float() # Convert to latents first latents = vae.encode(pixel_values).latent_dist.sample() latents = latents * vae.config.scaling_factor print_gpu_memory() print("After VAE encode") # Sample noise with very low variance for ultra-sharp edges noise = torch.randn_like(latents) * 0.5 # Even less noise bsz = latents.shape[0] # Use fewer timesteps to focus on binary outputs max_timestep = min(400, int(400 * (1 - 0.7 * batch_step/max_train_steps))) # Start at 400, reduce to 120 timesteps = torch.randint(0, max_timestep, (bsz,), device=latents.device) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample # Calculate UNet loss with extreme focus on binary outputs mse_loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]) l1_loss = F.l1_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]) # 90% L1 loss to force binary-like outputs unet_loss = (0.1 * mse_loss + 0.9 * l1_loss).mean() # Backpropagate UNet loss first to free memory accelerator.backward(unet_loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm) unet_optimizer.step() unet_optimizer.zero_grad() # Clear memory before VAE training del noise_pred, encoder_hidden_states, noisy_latents, latents torch.cuda.empty_cache() # Calculate VAE reconstruction loss with binary focus if not freeze_vae: # Process in chunks with gradient accumulation chunk_size = 1 # Process one image at a time num_chunks = pixel_values.shape[0] vae_optimizer.zero_grad() # Zero gradients before accumulation for i in range(num_chunks): # Get single image img = pixel_values[i:i+1] binary_target = (img > 0.5).float() # Forward pass with reduced precision with torch.cuda.amp.autocast(): vae_output = vae(img) chunk_loss = ( 0.9 * F.l1_loss(vae_output.sample, binary_target) + 0.1 * F.mse_loss(vae_output.sample, binary_target) ) / num_chunks # Scale loss for accumulation # Backward pass accelerator.backward(chunk_loss) # Clean up del vae_output, chunk_loss torch.cuda.empty_cache() # Update VAE weights after accumulation if accelerator.sync_gradients: accelerator.clip_grad_norm_(vae.parameters(), max_grad_norm) 
                vae_optimizer.step()
                vae_optimizer.zero_grad()

        # Final cleanup
        del pixel_values
        torch.cuda.empty_cache()
        print_gpu_memory()
        print("After cleanup")

        # Increment batch step by the actual number of examples processed
        batch_step += train_batch_size

        # Log metrics
        if accelerator.is_main_process:
            current_lr = unet_optimizer.param_groups[0]['lr']
            monitor.log_step(unet_loss.detach().item(), current_lr, 0, batch_step)  # single-epoch loop, so epoch is 0

            # Update progress bar with actual training step
            progress_bar.n = batch_step  # Set absolute progress
            progress_bar.refresh()
            progress_bar.set_postfix(loss=f"{unet_loss.detach().item():.4f}", lr=f"{current_lr:.2e}")
            logger.info(f"DEBUG: Step {batch_step}, Next checkpoint at: {min([s for s in checkpoint_steps if s > batch_step], default=None)}")

            # Check if we should save checkpoint
            for checkpoint_step in checkpoint_steps:
                if batch_step >= checkpoint_step and checkpoint_step > last_saved_checkpoint:
                    logger.info(f"\n{'='*20} ATTEMPTING TO SAVE CHECKPOINT {checkpoint_step} {'='*20}")
                    logger.info(f"Current step: {batch_step}")
                    logger.info(f"Checkpoint steps: {checkpoint_steps}")
                    logger.info(f"Is main process: {accelerator.is_main_process}")
                    try:
                        checkpoint_dir = output_dir / f"checkpoint-{checkpoint_step}"
                        logger.info(f"Creating checkpoint directory at: {checkpoint_dir}")
                        checkpoint_dir.mkdir(parents=True, exist_ok=True)
                        if not checkpoint_dir.exists():
                            raise RuntimeError(f"Failed to create checkpoint directory: {checkpoint_dir}")

                        # Save model weights
                        unwrapped_unet = accelerator.unwrap_model(unet)
                        model_path = checkpoint_dir / "unet.safetensors"
                        logger.info(f"Saving model to: {model_path}")
                        save_file(unwrapped_unet.state_dict(), str(model_path))
                        logger.info("Model weights saved successfully")

                        # Save optimizer state
                        torch.save(unet_optimizer.state_dict(), checkpoint_dir / "optimizer.bin")
                        logger.info("Optimizer state saved successfully")

                        # Save VAE model weights
                        unwrapped_vae = accelerator.unwrap_model(vae)
                        model_path = checkpoint_dir / "vae.safetensors"
                        logger.info(f"Saving VAE model to: {model_path}")
                        save_file(unwrapped_vae.state_dict(), str(model_path))
                        logger.info("VAE model weights saved successfully")

                        # Save VAE optimizer state
                        torch.save(vae_optimizer.state_dict(), checkpoint_dir / "vae_optimizer.bin")
                        logger.info("VAE optimizer state saved successfully")

                        # Save training configuration
                        training_args = {
                            "learning_rate": learning_rate,
                            "num_train_steps": total_train_steps,
                            "warmup_steps": warmup_steps,
                            "max_grad_norm": max_grad_norm,
                            "train_batch_size": train_batch_size,
                            "gradient_accumulation_steps": gradient_accumulation_steps,
                            "mixed_precision": mixed_precision,
                            "image_size": image_size,
                            "last_step": batch_step
                        }
                        with open(checkpoint_dir / "training_args.json", "w") as f:
                            json.dump(training_args, f, indent=2)

                        # Save accelerator state
                        accelerator.save_state(checkpoint_dir)
                        logger.info(f"Successfully saved checkpoint at step {checkpoint_step}")
                        logger.info("=" * 60 + "\n")

                        # Update last saved checkpoint
                        last_saved_checkpoint = checkpoint_step
                    except Exception as e:
                        logger.error(f"Error during checkpoint saving at step {batch_step}: {str(e)}")
                    finally:
                        cleanup()
                        torch.cuda.empty_cache()

        # Log progress every 100 steps
        if batch_step % 100 == 0 and accelerator.is_main_process:
            current_lr = unet_optimizer.param_groups[0]['lr']
            logger.info(f"Step {batch_step}, Loss: {unet_loss.item():.4f}, LR: {current_lr:.2e}")

        # Update learning rate
        lr_scheduler.step()
        vae_lr_scheduler.step()

        if batch_step >= max_train_steps:
            break

    # Save the final model and metrics
    accelerator.wait_for_everyone()
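
    # Resuming sketch (an illustrative addition; "checkpoint-1000" is an
    # example directory name). The checkpoint files written above can be
    # reloaded with:
    #   from safetensors.torch import load_file
    #   unet.load_state_dict(load_file("output/checkpoint-1000/unet.safetensors"))
    #   unet_optimizer.load_state_dict(torch.load("output/checkpoint-1000/optimizer.bin"))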
logger.info("Saving final model...") # Save final metrics monitor.save_metrics() # Clear cache before final save cleanup() # Save final model unet = accelerator.unwrap_model(unet) vae = accelerator.unwrap_model(vae) pipeline = StableDiffusionPipeline.from_pretrained( pretrained_model_name, unet=unet, vae=vae, torch_dtype=torch.float16 ) pipeline.save_pretrained(output_dir / "final_model") logger.info(f"Training completed. Model saved to {output_dir / 'final_model'}") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--data_dir", type=str, required=True, help="Path to data directory") parser.add_argument("--output_dir", type=str, required=True, help="Path to save model") parser.add_argument("--pretrained_model", type=str, default="CompVis/stable-diffusion-v1-4") parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--grad_accum", type=int, default=2) parser.add_argument("--lr", type=float, default=2e-4) parser.add_argument("--vae_lr", type=float, default=2e-5) parser.add_argument("--max_steps", type=int, default=6000) parser.add_argument("--mixed_precision", type=str, default="fp16") parser.add_argument("--project_name", type=str, default="kanji-sd") parser.add_argument("--run_name", type=str, default=None) parser.add_argument("--enable_wandb", action="store_true", help="Enable wandb logging") parser.add_argument("--warmup_steps", type=int, default=100) parser.add_argument("--max_grad_norm", type=float, default=1.0) parser.add_argument("--freeze_text_encoder", action="store_true", help="Freeze text encoder") parser.add_argument("--freeze_vae", action="store_true", help="Freeze VAE") parser.add_argument("--enable_gradient_checkpointing", action="store_true", help="Enable gradient checkpointing") parser.add_argument("--enable_mem_efficient_attention", action="store_true", help="Enable memory efficient attention") args = parser.parse_args() train_model( data_dir=args.data_dir, output_dir=args.output_dir, pretrained_model_name=args.pretrained_model, train_batch_size=args.batch_size, gradient_accumulation_steps=args.grad_accum, learning_rate=args.lr, vae_learning_rate=args.vae_lr, max_train_steps=args.max_steps, mixed_precision=args.mixed_precision, project_name=args.project_name, run_name=args.run_name, enable_wandb=args.enable_wandb, warmup_steps=args.warmup_steps, max_grad_norm=args.max_grad_norm, freeze_text_encoder=args.freeze_text_encoder, freeze_vae=args.freeze_vae, enable_gradient_checkpointing=args.enable_gradient_checkpointing, enable_mem_efficient_attention=args.enable_mem_efficient_attention )