import os
import gc

import torch
from PIL.Image import Image
from dataclasses import dataclass
from diffusers import DiffusionPipeline, AutoencoderTiny, FluxTransformer2DModel
from transformers import T5EncoderModel
from huggingface_hub.constants import HF_HUB_CACHE
from torchao.quantization import quantize_, int8_weight_only, float8_weight_only
from caching import apply_cache_on_pipe
from pipelines.models import TextToImageRequest
from torch import Generator


# Configuration settings using a dataclass for clarity
@dataclass
class Config:
    CKPT_ID: str = "black-forest-labs/FLUX.1-schnell"
    CKPT_REVISION: str = "741f7c3ce8b383c54771c7003378a50191e9efe9"
    DEVICE: str = "cuda"
    DTYPE = torch.bfloat16
    PYTORCH_CUDA_ALLOC_CONF: str = "expandable_segments:True"


def _initialize_environment():
    """Set up PyTorch and CUDA environment variables for optimal performance."""
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = Config.PYTORCH_CUDA_ALLOC_CONF


def _clear_gpu_memory():
    """Free GPU memory and reset allocator statistics before loading models."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def _load_text_encoder_model():
    """Load the T5 text encoder at the pinned revision and dtype."""
    return T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        torch_dtype=Config.DTYPE,
    ).to(memory_format=torch.channels_last)


def _load_vae_model():
    """Load the lightweight variational autoencoder (VAE) at the pinned revision and dtype."""
    return AutoencoderTiny.from_pretrained(
        "RobertML/FLUX.1-schnell-vae_e3m2",
        revision="da0d2cd7815792fb40d084dbd8ed32b63f153d8d",
        torch_dtype=Config.DTYPE,
    )
    # return AutoencoderTiny.from_pretrained(
    #     "manbeast3b/FLUX.1-schnell-taef1-float8",
    #     revision="7c538d53ec698509788ed88b1305c6bb019bdb4d",
    #     torch_dtype=Config.DTYPE,
    # )


def _load_transformer_model():
    """Load the FLUX transformer from a specific snapshot in the local Hugging Face cache."""
    # transformer_path = os.path.join(
    #     HF_HUB_CACHE,
    #     "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
    # )
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
        "transformer",
    )
    return FluxTransformer2DModel.from_pretrained(
        transformer_path,
        torch_dtype=Config.DTYPE,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)


def _warmup_pipeline(pipeline):
    """Warm up the pipeline by running it a few times on a blank prompt to initialize internal caches."""
    for _ in range(3):
        pipeline(prompt=" ")


def load_pipeline():
    """
    Load and configure the diffusion pipeline for text-to-image generation.

    Returns:
        DiffusionPipeline: The configured pipeline ready for inference.
    """
    _clear_gpu_memory()

    # Load individual components
    text_encoder = _load_text_encoder_model()
    vae = _load_vae_model()
    transformer = _load_transformer_model()

    # Assemble the diffusion pipeline
    pipeline = DiffusionPipeline.from_pretrained(
        Config.CKPT_ID,
        vae=vae,
        revision=Config.CKPT_REVISION,
        transformer=transformer,
        text_encoder_2=text_encoder,
        torch_dtype=Config.DTYPE,
    ).to(Config.DEVICE)

    # Apply optimizations
    apply_cache_on_pipe(pipeline)
    pipeline.to(memory_format=torch.channels_last)
    pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
    quantize_(pipeline.vae, int8_weight_only())
    quantize_(pipeline.vae, float8_weight_only())

    # Warm up the pipeline to ensure readiness
    _warmup_pipeline(pipeline)

    return pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: DiffusionPipeline, generator: Generator) -> Image:
    """
    Generate an image from a text prompt using the diffusion pipeline.

    Args:
        request (TextToImageRequest): The request containing the prompt and image parameters.
        pipeline (DiffusionPipeline): The pre-loaded diffusion pipeline.
        generator (Generator): The seeded random number generator for reproducibility.

    Returns:
        Image: The generated image in PIL format.
    """
    image = pipeline(
        prompt=request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    ).images[0]
    return image


# Initialize environment settings when the module is imported
_initialize_environment()

# For compatibility with other scripts, alias load_pipeline as load
load = load_pipeline