"""Gradio demo: upscale images with the FLUX.1-dev ControlNet upscaler."""

import gc
import logging
import os
import random
import warnings

import gradio as gr
import numpy as np
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from gradio_imageslider import ImageSlider
from huggingface_hub import snapshot_download
from PIL import Image

logger = logging.getLogger(__name__)

# Release any memory left over before loading the models.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

css = """
#col-container {
    margin: 0 auto;
    max-width: 512px;
}
"""

# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
# FLUX.1-dev is a very large model (~12B-parameter transformer): float32
# weights double the memory of bfloat16 and do not fit on typical GPUs.
# float32 is kept on CPU, where bfloat16 kernels are slow or unsupported.
dtype = torch.bfloat16 if device == "cuda" else torch.float32
huggingface_token = os.getenv("HF_TOKEN")

# Keyword arguments shared by every from_pretrained() call below.
model_config = {
    "low_cpu_mem_usage": True,
    "torch_dtype": dtype,
    # Only *.safetensors weights are downloaded (*.bin is excluded below), so
    # safetensors loading must stay enabled or the weight files are not found.
    "use_safetensors": True,
}

model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    # Skip docs, git metadata and pickle-format weights.
    ignore_patterns=["*.md", "*.gitattributes", "*.bin"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)

# Load models; log the full traceback before propagating any failure.
try:
    controlnet = FluxControlNetModel.from_pretrained(
        "jasperai/Flux.1-dev-Controlnet-Upscaler", **model_config
    )
    controlnet.to(device)
    pipe = FluxControlNetPipeline.from_pretrained(
        model_path, controlnet=controlnet, **model_config
    )
    pipe.to(device)
except Exception:
    logger.exception("Error loading models")
    raise

# Memory optimizations. Attention slicing is requested where the models
# support it. VAE slicing is enabled on the VAE itself rather than through a
# pipeline-level wrapper, which not every pipeline class provides.
pipe.enable_attention_slicing(1)
pipe.vae.enable_slicing()

MAX_SEED = 1000000
# Upper bound on the number of pixels (width * height) of the generated image.
MAX_PIXEL_BUDGET = 64 * 64
# FLUX latents are 1/8 of the image size and are packed into 2x2 patches, so
# generation sizes must be multiples of 16 on both sides.
SIZE_MULTIPLE = 16


def process_input(input_image, upscale_factor):
    """Prepare an uploaded image for use as the ControlNet conditioning image.

    The image is converted to RGB, then shrunk (preserving its aspect ratio)
    so that the upscaled output fits within ``MAX_PIXEL_BUDGET`` pixels. It is
    finally resized so both sides are multiples of ``SIZE_MULTIPLE``.

    Args:
        input_image: PIL image supplied by the ``gr.Image`` input.
        upscale_factor: Factor by which the output will be enlarged
            (values below 1 are treated as 1).

    Returns:
        Tuple ``(image, w, h)``: the prepared image and its width and height.
    """
    input_image = input_image.convert("RGB")
    w, h = input_image.size
    upscale_factor = max(int(upscale_factor), 1)

    # One scale for both sides keeps the aspect ratio; it is chosen so that
    # (w * scale * f) * (h * scale * f) stays within the pixel budget.
    scale = (MAX_PIXEL_BUDGET / (w * h * upscale_factor**2)) ** 0.5
    if scale < 1:
        logger.info("Input resized from %dx%d to fit the pixel budget", w, h)
        w = max(int(w * scale), 1)
        h = max(int(h * scale), 1)
        input_image = input_image.resize((w, h), Image.LANCZOS)

    # Round down to the required multiple, but never below one multiple.
    w = max(w - w % SIZE_MULTIPLE, SIZE_MULTIPLE)
    h = max(h - h % SIZE_MULTIPLE, SIZE_MULTIPLE)
    return input_image.resize((w, h)), w, h


def infer(
    seed,
    randomize_seed,
    input_image,
    num_inference_steps,
    upscale_factor,
    controlnet_conditioning_scale,
    progress=gr.Progress(track_tqdm=True),
):
    """Upscale ``input_image`` with the FLUX ControlNet upscaler.

    Args:
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` when true.
        input_image: PIL image from the input component (may be None).
        num_inference_steps: Number of denoising steps.
        upscale_factor: Output enlargement factor.
        controlnet_conditioning_scale: Strength of the ControlNet guidance.
        progress: Gradio progress tracker (mirrors tqdm bars).

    Returns:
        ``((control_image, result_image), seed)``: the before/after pair for
        the ImageSlider and the seed actually used.

    Raises:
        gr.Error: if no image was provided or generation failed, so that
            Gradio shows the message to the user.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    upscale_factor = max(int(upscale_factor), 1)

    try:
        input_image, w, h = process_input(input_image, upscale_factor)
        # The ControlNet is conditioned on an image at the output resolution.
        control_image = input_image.resize(
            (w * upscale_factor, h * upscale_factor)
        )
        out_w, out_h = control_image.size
        with torch.inference_mode():
            generator = torch.Generator(device=device).manual_seed(seed)
            image = pipe(
                prompt="",
                control_image=control_image,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
                num_inference_steps=int(num_inference_steps),
                guidance_scale=1.5,
                height=out_h,
                width=out_w,
                generator=generator,
            ).images[0]
    except Exception as e:
        logger.exception("Upscaling failed")
        # gr.Error must be raised (not just constructed) to reach the UI.
        raise gr.Error(f"Error: {str(e)}") from e
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return (control_image, image), seed


with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    with gr.Row():
        run_button = gr.Button(value="Run")
    with gr.Row():
        with gr.Column(scale=4):
            input_im = gr.Image(label="Input Image", type="pil")
        with gr.Column(scale=1):
            num_inference_steps = gr.Slider(
                label="Steps",
                minimum=1,
                maximum=10,
                step=1,
                value=5,
            )
            upscale_factor = gr.Slider(
                label="Scale",
                minimum=1,
                maximum=1,
                step=1,
                value=1,
            )
            controlnet_conditioning_scale = gr.Slider(
                label="Control Scale",
                minimum=0.1,
                maximum=0.3,
                step=0.1,
                value=0.2,
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            randomize_seed = gr.Checkbox(label="Random Seed", value=True)
    with gr.Row():
        result = ImageSlider(label="Result", type="pil", interactive=True)

    # Shared wiring: infer() returns the image pair for the slider and the
    # seed it used, which is written back to the Seed slider.
    infer_inputs = [
        seed,
        randomize_seed,
        input_im,
        num_inference_steps,
        upscale_factor,
        controlnet_conditioning_scale,
    ]
    infer_outputs = [result, seed]

    current_dir = os.path.dirname(os.path.abspath(__file__))
    examples = gr.Examples(
        examples=[
            [42, False, os.path.join(current_dir, "z1.webp"), 5, 1, 0.2],
            [42, False, os.path.join(current_dir, "z2.webp"), 5, 1, 0.2],
        ],
        inputs=infer_inputs,
        fn=infer,
        outputs=infer_outputs,
        cache_examples=False,
    )

    gr.on(
        [run_button.click],
        fn=infer,
        inputs=infer_inputs,
        outputs=infer_outputs,
        show_api=False,
    )
# Launch configuration: keep at most one request waiting in the queue and cap
# the app at one worker thread. Queuing is enabled by .queue() and example
# caching is configured on gr.Examples; launch() accepts neither
# `enable_queue` nor `cache_examples`, so they are not passed here.
demo.queue(max_size=1).launch(
    share=False,
    debug=True,
    show_error=True,
    max_threads=1,
    quiet=True,
)