"""FLUX.1-schnell text-to-image pipeline: loading, warm-up and inference."""

import gc
import time

import torch
from diffusers import AutoencoderTiny, DiffusionPipeline, FluxPipeline
from PIL import Image as img
from PIL.Image import Image
from torch import Generator
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    T5EncoderModel,
    T5TokenizerFast,
)

from pipelines.models import TextToImageRequest

# from torchao.quantization import quantize_, fpx_weight_only, int8_weight_only

# Placeholder name used only in the annotations below.
Pipeline = None

ckpt_id = "black-forest-labs/FLUX.1-schnell"


def empty_cache():
    """Run a GC pass, release cached CUDA memory and reset peak-memory stats.

    Prints how long the flush took.
    """
    start = time.time()
    gc.collect()
    torch.cuda.empty_cache()
    # reset_peak_memory_stats() also resets the max-allocated counter, so the
    # deprecated torch.cuda.reset_max_memory_allocated() call is not needed.
    torch.cuda.reset_peak_memory_stats()
    print(f"Flush took: {time.time() - start}")


def load_pipeline() -> Pipeline:
    """Build the diffusion pipeline, enable CPU offload and warm it up.

    The VAE, the CLIP text encoder and the T5 text encoder are loaded
    separately in bfloat16 and handed to ``DiffusionPipeline.from_pretrained``
    for ``ckpt_id``.  Sequential CPU offload is enabled, then two warm-up
    generations are run before the pipeline is returned.

    Returns:
        The assembled pipeline object.
    """
    empty_cache()

    dtype = torch.bfloat16

    vae = AutoencoderTiny.from_pretrained(
        "ColdAsIce123/Flux.1Schell_vaee3m2", torch_dtype=dtype
    )

    # Text encoder (CLIP), taken from the base checkpoint.
    text_encoder = CLIPTextModel.from_pretrained(
        ckpt_id, subfolder="text_encoder", torch_dtype=dtype
    )

    # Text encoder 2 (T5), taken from a separate bf16 encoder-only repo.
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=dtype
    )

    empty_cache()

    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        vae=vae,
        torch_dtype=dtype,
    )
    pipeline.enable_sequential_cpu_offload()

    # Warm-up runs; the outputs are discarded.
    for _ in range(2):
        gc.collect()
        pipeline(
            prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline


def create_gray_image(width: int, height: int) -> Image:
    """Create a solid gray RGB image with the specified dimensions.

    ``Image`` in this module is the PIL image *class*; the ``new`` factory
    lives on the ``PIL.Image`` module, which is imported here as ``img``.
    """
    return img.new('RGB', (width, height), color='gray')


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for ``request``, falling back to a gray placeholder.

    Args:
        request: Supplies ``prompt``, ``seed``, ``height`` and ``width``.
        pipeline: The callable pipeline returned by ``load_pipeline``.

    Returns:
        The first generated image, or a solid gray image of the requested
        size if generation raised an exception.
    """
    gc.collect()
    try:
        generator = Generator("cuda").manual_seed(request.seed)
        image = pipeline(
            request.prompt,
            generator=generator,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            height=request.height,
            width=request.width,
            output_type="pil",
        ).images[0]
    except Exception as exc:
        # Best-effort fallback is deliberate, but report the failure and let
        # KeyboardInterrupt/SystemExit propagate instead of swallowing them.
        print(f"Inference failed, returning gray placeholder: {exc!r}")
        image = create_gray_image(request.width, request.height)
    return image