# flux-IP-adapter / app.py
# Last updated by multimodalart in commit 07afe68 ("Update app.py", verified); 4.21 kB.
import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
import os
from pipeline_flux_ipa import FluxPipeline
from transformer_flux import FluxTransformer2DModel
from attention_processor import IPAFluxAttnProcessor2_0
from transformers import AutoProcessor, SiglipVisionModel
from infer_flux_ipa_siglip import MLPProjModel, IPAdapter
from huggingface_hub import hf_hub_download
import spaces
# Constants
MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider and random seeds
MAX_IMAGE_SIZE = 1024  # upper bound for the width/height sliders
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face Hub ids of the base FLUX model and the SigLIP image encoder.
BASE_MODEL_ID = "black-forest-labs/FLUX.1-dev"
image_encoder_path = "google/siglip-so400m-patch14-384"
# Local path of the downloaded IP-Adapter weights.
ipadapter_path = hf_hub_download(
    repo_id="InstantX/FLUX.1-dev-IP-Adapter",
    filename="ip-adapter.bin",
)

# Load the project's FluxTransformer2DModel separately so it can be handed to
# the pipeline explicitly.
transformer = FluxTransformer2DModel.from_pretrained(
    BASE_MODEL_ID,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    BASE_MODEL_ID,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
# Use the detected DEVICE rather than a hard-coded "cuda", so the device choice
# made above is actually honoured (the constant was previously dead code).
ip_model = IPAdapter(
    pipe,
    image_encoder_path,
    ipadapter_path,
    device=DEVICE,
    num_tokens=128,
)
def resize_img(image, max_size=1024):
    """Scale *image* so its longer side is exactly *max_size*, keeping aspect ratio.

    Images smaller than *max_size* are enlarged; larger ones are shrunk.
    The shorter side is floored to a whole pixel but never drops below 1.

    Args:
        image: A PIL image (anything exposing ``.size`` and ``.resize``).
        max_size: Target length, in pixels, of the longer side.

    Returns:
        The resized image, resampled with Lanczos filtering.

    Raises:
        ValueError: If the image has a zero (or negative) width or height.
    """
    width, height = image.size
    if width <= 0 or height <= 0:
        raise ValueError(f"cannot resize an image with size {image.size}")
    # Integer arithmetic on purpose: the float form `int(w * (max_size / w))`
    # can come out as max_size - 1 (e.g. w=49 -> 1023.999... -> 1023).
    if width >= height:
        new_width = max_size
        new_height = max(1, height * max_size // width)
    else:
        new_height = max_size
        new_width = max(1, width * max_size // height)
    return image.resize((new_width, new_height), Image.LANCZOS)
@spaces.GPU
def process_image(
    image,
    prompt,
    scale,
    seed,
    randomize_seed,
    width,
    height,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image from a reference image plus a text prompt via the IP-Adapter.

    Args:
        image: Reference image; a PIL image, or an array accepted by
            ``Image.fromarray``. ``None`` skips generation.
        prompt: Text prompt forwarded to ``ip_model.generate``.
        scale: Value forwarded as ``scale`` (the UI slider offers 0.0-1.0).
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: If true, draw a fresh seed from ``[0, MAX_SEED]``.
        width: Requested output width in pixels.
        height: Requested output height in pixels.
        progress: Gradio progress tracker. It is not referenced in the body;
            it is declared so Gradio can mirror tqdm progress to the UI.

    Returns:
        ``(generated_image, seed)``: the first item returned by
        ``ip_model.generate`` (or ``None`` when no image was given) and the
        seed actually used, so the UI can show it.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats (e.g. 42.0); generate() presumably
    # seeds a torch generator, which requires an int.
    if seed is not None:
        seed = int(seed)
    if image is None:
        return None, seed

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Normalise palette/RGBA/greyscale uploads to RGB before resizing; Pillow
    # would otherwise ignore LANCZOS for "P"/"1" mode images.
    image = resize_img(image.convert("RGB"))

    result = ip_model.generate(
        pil_image=image,
        prompt=prompt,
        scale=scale,
        width=int(width),
        height=int(height),
        seed=seed,
    )
    return result[0], seed
# Stylesheet handed to gr.Blocks: the "col-container" column gets auto side
# margins and a 960px maximum width.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 960px;\n"
    "}\n"
)
# Create the Gradio interface
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Image Processing Model")

        with gr.Row():
            # Left column: inputs and the trigger button.
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="pil",
                )
                prompt = gr.Text(
                    label="Prompt",
                    max_lines=1,
                    placeholder="Enter your prompt",
                )
                run_button = gr.Button("Process", variant="primary")
            # Right column: generated output.
            with gr.Column():
                result = gr.Image(label="Result")

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            # Defaults keep a 3:4 portrait ratio but stay within the slider
            # bounds; the previous height default (1280) exceeded
            # MAX_IMAGE_SIZE (1024).
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=768,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            scale = gr.Slider(
                label="Scale",
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.7,
            )

    # The seed slider is also an output so the UI shows the seed actually used.
    run_button.click(
        fn=process_image,
        inputs=[
            input_image,
            prompt,
            scale,
            seed,
            randomize_seed,
            width,
            height,
        ],
        outputs=[result, seed],
    )
# Start the server only when run as a script; importing this module just
# builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch()