import os import gc import time import torch import torch.nn.functional as F from PIL import Image as img from PIL.Image import Image from typing import Optional, Type from dataclasses import dataclass from diffusers import ( FluxTransformer2DModel, DiffusionPipeline, AutoencoderTiny ) from transformers import T5EncoderModel from huggingface_hub.constants import HF_HUB_CACHE from torchao.quantization import quantize_, int8_weight_only from first_block_cache.diffusers_adapters import apply_cache_on_pipe from pipelines.models import TextToImageRequest from torch import Generator # Configuration @dataclass class Config: CKPT_ID: str = "black-forest-labs/FLUX.1-schnell" CKPT_REVISION: str = "741f7c3ce8b383c54771c7003378a50191e9efe9" DEVICE: str = "cuda" DTYPE = torch.bfloat16 PYTORCH_CUDA_ALLOC_CONF: str = "expandable_segments:True" # Initialize global settings def init_global_settings(): torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True os.environ['PYTORCH_CUDA_ALLOC_CONF'] = Config.PYTORCH_CUDA_ALLOC_CONF # Tensor comparison utilities class TensorComparator: @staticmethod def orig_comparison(t1, t2, *, threshold=0.85): mean_diff = (t1 - t2).abs().mean() mean_t1 = t1.abs().mean() diff = mean_diff / mean_t1 return diff.item() < threshold @staticmethod def mse_comparison(t1, t2, threshold=0.95): mse = F.mse_loss(t1, t2) return mse.item() < threshold @staticmethod def relative_comparison(t1, t2, threshold=0.15): with torch.no_grad(): mean_diff = torch.mean(torch.abs(t1 - t2)) mean_t1 = torch.mean(torch.abs(t1)) relative_diff = mean_diff / (mean_t1 + 1e-8) return relative_diff.item() < threshold @staticmethod def normalized_comparison(t1, t2, threshold=0.85): with torch.no_grad(): t1_norm = (t1 - t1.mean()) / (t1.std() + 1e-8) t2_norm = (t2 - t2.mean()) / (t2.std() + 1e-8) diff = torch.mean(torch.abs(t1_norm - t2_norm)) return diff.item() < threshold @staticmethod def l1_comparison(t1, t2, threshold=0.85): with torch.no_grad(): l1_dist = torch.nn.L1Loss()(t1, t2) return l1_dist.item() < threshold @staticmethod def max_diff_comparison(t1, t2, threshold=0.85): with torch.no_grad(): max_diff = torch.max(torch.abs(t1 - t2)) return max_diff.item() < threshold # Memory management class MemoryManager: @staticmethod def empty_cache(): gc.collect() torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() torch.cuda.reset_peak_memory_stats() # Pipeline management class PipelineManager: @staticmethod def load_pipeline() -> DiffusionPipeline: MemoryManager.empty_cache() text_encoder_2 = T5EncoderModel.from_pretrained( "city96/t5-v1_1-xxl-encoder-bf16", revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86", torch_dtype=Config.DTYPE ).to(memory_format=torch.channels_last) vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_e3m2", revision="da0d2cd7815792fb40d084dbd8ed32b63f153d8d", torch_dtype=Config.DTYPE) path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a") model = FluxTransformer2DModel.from_pretrained( path, torch_dtype=Config.DTYPE, use_safetensors=False ).to(memory_format=torch.channels_last) pipeline = DiffusionPipeline.from_pretrained( Config.CKPT_ID, vae=vae, revision=Config.CKPT_REVISION, transformer=model, text_encoder_2=text_encoder_2, torch_dtype=Config.DTYPE, ).to(Config.DEVICE) apply_cache_on_pipe(pipeline) pipeline.to(memory_format=torch.channels_last) pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune") quantize_(pipeline.vae, int8_weight_only()) PipelineManager._warmup(pipeline) return pipeline @staticmethod def _warmup(pipeline): for _ in range(3): pipeline( prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256 ) @staticmethod @torch.no_grad() def infer(request: TextToImageRequest, pipeline: DiffusionPipeline, generator: Generator) -> Image: try: image = pipeline( request.prompt, generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil" ).images[0] except: print("using backup") image = img.open("./RobertML.png") return image # Initialize global settings init_global_settings() # Keep original interface load_pipeline = PipelineManager.load_pipeline infer = PipelineManager.infer are_two_tensors_similar = TensorComparator.orig_comparison are_two_tensors_similar_relative = TensorComparator.relative_comparison are_two_tensors_similar_normalized = TensorComparator.normalized_comparison are_two_tensors_similar_l1 = TensorComparator.l1_comparison are_two_tensors_similar_max_diff = TensorComparator.max_diff_comparison empty_cache = MemoryManager.empty_cache