import base64
import logging
import os
from io import BytesIO
from typing import Any, Dict, Tuple

import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline, FluxPipeline
from PIL import Image

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set device to GPU; raise an error if none is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise ValueError("This handler requires a GPU to run.")


class EndpointHandler:
    MAX_SEED = 1000000
    MAX_PIXEL_BUDGET = 10240 * 10240  # Adjust as needed

    def __init__(self, path: str = ""):
        """
        Initialize the Flux pipeline and the Flux ControlNet pipeline with
        optional upscaling support.
        """
        # Retrieve the Hugging Face token from the environment
        hf_token = os.getenv("HUGGINGFACE_TOKEN")
        if hf_token is None:
            raise ValueError("HUGGINGFACE_TOKEN environment variable not set.")

        # Initialize FluxPipeline for text-to-image generation
        logger.info("Loading FluxPipeline...")
        self.pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=hf_token
        )
        self.pipe.to(device)
        logger.info("FluxPipeline initialized.")

        # Initialize FluxControlNetPipeline for upscaling
        logger.info(
            "Loading FluxControlNetModel and FluxControlNetPipeline for upscaling..."
        )
        self.controlnet = FluxControlNetModel.from_pretrained(
            "jasperai/Flux.1-dev-Controlnet-Upscaler",
            torch_dtype=torch.bfloat16,
            token=hf_token,
        )
        self.upscale_pipe = FluxControlNetPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            controlnet=self.controlnet,
            torch_dtype=torch.bfloat16,
            token=hf_token,
        )
        self.upscale_pipe.to(device)
        logger.info("FluxControlNetPipeline initialized for upscaling.")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """
        Handle incoming inference requests, with optional upscaling.
        """
        try:
            # Check whether upscaling is requested
            upscale_factor = data.get("upscale_factor", None)
            if upscale_factor:
                # Validate the upscale factor
                if upscale_factor not in [2, 4, 8]:
                    raise ValueError(
                        "Unsupported upscale factor. Choose from 2, 4, or 8."
                    )

                # A control image is required for upscaling
                control_image_b64 = data.get("control_image", None)
                if not control_image_b64:
                    raise ValueError(
                        "Control image 'control_image' is required for upscaling."
                    )

                # Decode the input image
                input_image = self.decode_base64_image(control_image_b64)

                # Downscale if the requested output would exceed the pixel
                # budget, and snap dimensions to multiples of 8
                processed_image, w_original, h_original, was_resized = (
                    self.process_input(input_image, upscale_factor)
                )
                logger.info(f"Original Image Size: {w_original}x{h_original}")
                logger.info(f"Upscale Factor: {upscale_factor}")

                # Resize the control image by the upscale factor. Use the
                # *processed* dimensions here, not the originals; otherwise a
                # budget-limited input would be blown back up past the budget.
                w, h = processed_image.size
                control_image = processed_image.resize(
                    (w * upscale_factor, h * upscale_factor),
                    Image.LANCZOS,
                )
                logger.info(f"Control Image Size after Upscale: {control_image.size}")

                # Extract hyperparameters, with defaults
                num_inference_steps = data.get("num_inference_steps", 28)
                guidance_scale = data.get("guidance_scale", 3.5)
                controlnet_conditioning_scale = data.get(
                    "controlnet_conditioning_scale", 0.6
                )
                # The upscaler works with an empty prompt
                prompt = data.get("inputs", "")

                logger.info("Running inference with upscaling...")
                output = self.upscale_pipe(
                    prompt=prompt,
                    control_image=control_image,
                    controlnet_conditioning_scale=controlnet_conditioning_scale,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    height=control_image.size[1],
                    width=control_image.size[0],
                )
                logger.info("Inference completed.")

                # Get the generated image
                generated_image = output.images[0]
                logger.info(
                    f"Generated Image Size before optional resize: {generated_image.size}"
                )

                if was_resized:
                    # Resize back to the target size implied by the original dimensions
                    target_w = w_original * upscale_factor
                    target_h = h_original * upscale_factor
                    generated_image = generated_image.resize(
                        (target_w, target_h), Image.LANCZOS
                    )
                    logger.info(f"Resized output image to {target_w}x{target_h}.")

                logger.info(f"Final Generated Image Size: {generated_image.size}")
            else:
                # Regular image generation without upscaling
                prompt = data.get("inputs", "")
                if not prompt:
                    raise ValueError("Prompt 'inputs' is required when not upscaling.")

                # Extract hyperparameters, with defaults
                num_inference_steps = data.get("num_inference_steps", 50)
                guidance_scale = data.get("guidance_scale", 3.5)
                height = data.get("height", 1024)
                width = data.get("width", 1024)
                max_sequence_length = data.get("max_sequence_length", 512)
                seed = data.get("seed", None)

                # Seed the generator for reproducible results, if requested
                if seed is not None:
                    generator = torch.Generator(device).manual_seed(seed)
                else:
                    generator = None

                logger.info("Running inference without upscaling...")
                output = self.pipe(
                    prompt=prompt,
                    height=height,
                    width=width,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    max_sequence_length=max_sequence_length,
                    generator=generator,
                )
                logger.info("Inference completed.")

                # Get the generated image
                generated_image = output.images[0]
                logger.info(f"Generated Image Size: {generated_image.size}")

            # Encode the generated image to base64
            image_b64 = self.encode_image_to_base64(generated_image)
            logger.info("Generated image encoded to base64.")

            return {"image": image_b64}

        except Exception as e:
            logger.error(f"Inference failed: {e}")
            raise ValueError(f"Inference failed: {e}")

    def process_input(
        self, input_image: Image.Image, upscale_factor: int
    ) -> Tuple[Image.Image, int, int, bool]:
        """
        Process the input image by checking the pixel budget and resizing if necessary.
""" original_w, original_h = input_image.size # Capture original dimensions aspect_ratio = original_w / original_h was_resized = False if original_w * original_h * upscale_factor**2 > self.MAX_PIXEL_BUDGET: new_w = int(aspect_ratio * (self.MAX_PIXEL_BUDGET**0.5) / upscale_factor) new_h = int((self.MAX_PIXEL_BUDGET**0.5) / (aspect_ratio * upscale_factor)) logger.warning( f"Requested output image is too large ({original_w * upscale_factor}x{original_h * upscale_factor}). " f"Resizing input to ({new_w}x{new_h}) pixels to fit the pixel budget." ) input_image = input_image.resize((new_w, new_h), Image.LANCZOS) was_resized = True # Ensure dimensions are multiples of 8 w, h = input_image.size w = w - (w % 8) h = h - (h % 8) input_image = input_image.resize((w, h), Image.LANCZOS) return input_image, original_w, original_h, was_resized @staticmethod def decode_base64_image(image_string: str) -> Image.Image: """ Decode a base64 string to a PIL Image. """ try: base64_image = base64.b64decode(image_string) buffer = BytesIO(base64_image) image = Image.open(buffer).convert("RGB") return image except Exception as e: raise ValueError(f"Failed to decode image: {e}") @staticmethod def encode_image_to_base64(image: Image.Image) -> str: """ Encode a PIL Image to a base64 string. """ try: buffered = BytesIO() image.save(buffered, format="PNG") img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") return img_str except Exception as e: raise ValueError(f"Failed to encode image: {e}")