import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
import os
from pipeline_flux_ipa import FluxPipeline
from transformer_flux import FluxTransformer2DModel
from attention_processor import IPAFluxAttnProcessor2_0
from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
from transformers import AutoProcessor, SiglipVisionModel
from infer_flux_ipa_siglip import MLPProjModel, IPAdapter
from huggingface_hub import hf_hub_download

# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Model / checkpoint locations. The adapter weights are fetched from the Hub
# at import time.
model_path = 'stabilityai/stable-diffusion-3.5-large'
image_encoder_path = "google/siglip-so400m-patch14-384"
ipadapter_path = hf_hub_download(
    repo_id="InstantX/SD3.5-Large-IP-Adapter",
    filename="ip-adapter.bin",
)

transformer = SD3Transformer2DModel.from_pretrained(
    model_path, subfolder="transformer", torch_dtype=torch.bfloat16
)

# NOTE(review): an SD3.5 checkpoint and an SD3Transformer2DModel are handed to
# FluxPipeline here, while StableDiffusion3Pipeline is imported but never
# used. Confirm this pairing is intended before relying on it.
pipe = FluxPipeline.from_pretrained(
    model_path, transformer=transformer, torch_dtype=torch.bfloat16
)

# NOTE(review): the device is hard-coded to "cuda" although DEVICE is computed
# above and otherwise unused -- confirm which one should win on CPU-only hosts.
ip_model = IPAdapter(
    pipe, image_encoder_path, ipadapter_path, device="cuda", num_tokens=128
)


def resize_img(image, max_size=1024):
    """Rescale *image* to fit a ``max_size`` x ``max_size`` box, keeping aspect ratio.

    The same factor is applied to both sides, so images smaller than the box
    are enlarged as well as larger ones being shrunk.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize((w, h), resample)`` method (a PIL image in this app).
        max_size: Edge length, in pixels, of the bounding box.

    Returns:
        Whatever ``image.resize`` returns for the computed size.
    """
    width, height = image.size
    scaling_factor = min(max_size / width, max_size / height)
    # int() truncates, so a very elongated input (e.g. 5000x1) would end up
    # with a 0-pixel side, which resize() rejects; keep at least one pixel.
    new_width = max(1, int(width * scaling_factor))
    new_height = max(1, int(height * scaling_factor))
    return image.resize((new_width, new_height), Image.LANCZOS)


@spaces.GPU
def process_image(
    image,
    prompt,
    scale,
    seed,
    randomize_seed,
    width,
    height,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image from a reference image and a text prompt.

    Args:
        image: Reference image; a PIL image, or an array accepted by
            ``Image.fromarray``. ``None`` short-circuits the call.
        prompt: Text prompt forwarded to ``ip_model.generate``.
        scale: Value forwarded as ``scale`` to ``ip_model.generate``.
        seed: Seed to use when ``randomize_seed`` is falsy.
        randomize_seed: When truthy, ``seed`` is replaced by a random integer
            in ``[0, MAX_SEED]``.
        width: Requested output width, forwarded unchanged.
        height: Requested output height, forwarded unchanged.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        A ``(result, seed)`` tuple: the first element returned by
        ``ip_model.generate`` (or ``None`` when no image was given) and the
        seed that was actually used.
    """
    # The seed is drawn before the empty-input check, so a randomized seed is
    # reported back even when nothing is generated.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    if image is None:
        return None, seed

    # Convert to PIL Image if needed
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Resize image
    image = resize_img(image)

    # Generate the image
    result = ip_model.generate(
        pil_image=image,
        prompt=prompt,
        scale=scale,
        width=width,
        height=height,
        seed=seed,
    )

    # presumably generate() returns a sequence of images; only the first is shown
    return result[0], seed


# UI CSS
css = """ #col-container { margin: 0 auto; max-width: 960px; } """


def _dimension_slider(label):
    """Build one of the two output-size sliders (256..MAX_IMAGE_SIZE, step 32)."""
    return gr.Slider(
        label=label,
        minimum=256,
        maximum=MAX_IMAGE_SIZE,
        step=32,
        value=1024,
    )


# Create the Gradio interface
# Components are instantiated in the same order as before; only the way the
# calls are spelled has changed.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# InstantX's SD3.5 IP Adapter")

        with gr.Row():
            # Left-hand side: reference image, adapter strength, prompt, trigger.
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                scale = gr.Slider(
                    label="Image Scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.7,
                )
                prompt = gr.Text(
                    label="Prompt",
                    max_lines=1,
                    placeholder="Enter your prompt",
                )
                run_button = gr.Button("Generate", variant="primary")

            # Right-hand side: generated output.
            with gr.Column():
                result = gr.Image(label="Result")

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = _dimension_slider("Width")
                height = _dimension_slider("Height")

    # Order must match the positional parameters of process_image.
    generation_inputs = [
        input_image,
        prompt,
        scale,
        seed,
        randomize_seed,
        width,
        height,
    ]
    # The seed slider is also an output so the seed actually used is shown.
    run_button.click(
        fn=process_image,
        inputs=generation_inputs,
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()