from __future__ import annotations

import os

# Install dependencies that are only distributed as git repos (done at startup,
# as is common for hosted demos).
os.system("pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers")
os.system("pip install -e git+https://github.com/alvanli/RDM-Region-Aware-Diffusion-Model.git@main#egg=guided_diffusion")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"

import math
import random

import gradio as gr
import torch
from PIL import Image, ImageOps

from run_edit import run_model
from cool_models import make_models

help_text = """"""  # Reserved for usage notes rendered below the controls.


def main():
    # Load all models once; the inner callbacks close over them.
    segmodel, model, diffusion, ldm, bert, clip_model, model_params = make_models()

    def load_sample():
        """Populate the UI with a sample image, sample prompts, and the edited result."""
        SAMPLE_IMAGE = "./flower1.jpg"
        input_image = Image.open(SAMPLE_IMAGE)

        from_text = "a flower"
        instruction = "a sunflower"
        negative_prompt = ""
        seed = 42
        guidance_scale = 5.0
        clip_guidance_scale = 150
        cutn = 16
        l2_sim_lambda = 10_000

        edited_image_1 = run_model(
            segmodel, model, diffusion, ldm, bert, clip_model, model_params,
            from_text, instruction, negative_prompt,
            input_image.convert("RGB"),
            seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda,
        )

        return [
            input_image, from_text, instruction, negative_prompt, seed,
            guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda,
            edited_image_1,
        ]

    def generate(
        input_image: Image.Image,
        from_text: str,
        instruction: str,
        negative_prompt: str,
        randomize_seed: bool,
        seed: int,
        guidance_scale: float,
        clip_guidance_scale: float,
        cutn: int,
        l2_sim_lambda: float,
    ):
        seed = random.randint(0, 100000) if randomize_seed else seed
        if instruction == "":
            # Nothing to edit: return the input unchanged.
            return [seed, input_image]

        torch.manual_seed(seed)
        edited_image_1 = run_model(
            segmodel, model, diffusion, ldm, bert, clip_model, model_params,
            from_text, instruction, negative_prompt,
            input_image.convert("RGB"),
            seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda,
        )
        return [seed, edited_image_1]

    def reset():
        # Values are ordered to match the outputs of reset_button.click below.
        return ["Randomize Seed", 42, None, 5.0, 150, 16, 10000]

    with gr.Blocks() as demo:
        gr.HTML("""

<p style="text-align: center; font-size: 24px; font-weight: bold;">RDM: Region-Aware Diffusion for Zero-shot Text-driven Image Editing</p>
In the "From Text" field, specify the object you are trying to modify, in the "edit instruction" field, specify what you want that area to be turned into """) with gr.Row(): with gr.Column(scale=1, min_width=100): generate_button = gr.Button("Generate") with gr.Column(scale=1, min_width=100): load_button = gr.Button("Load Example") with gr.Column(scale=1, min_width=100): reset_button = gr.Button("Reset") with gr.Column(scale=3): from_text = gr.Textbox(lines=1, label="From Text", interactive=True) instruction = gr.Textbox(lines=1, label="Edit Instruction", interactive=True) negative_prompt = gr.Textbox(lines=1, label="Negative Prompt", interactive=True) with gr.Row(): input_image = gr.Image(label="Input Image", type="pil", interactive=True) edited_image_1 = gr.Image(label=f"Edited Image", type="pil", interactive=False) # edited_image_2 = gr.Image(label=f"Edited Image", type="pil", interactive=False) input_image.style(height=512, width=512) edited_image_1.style(height=512, width=512) # edited_image_2.style(height=512, width=512) with gr.Row(): # steps = gr.Number(value=50, precision=0, label="Steps", interactive=True) seed = gr.Number(value=1371, precision=0, label="Seed", interactive=True) guidance_scale = gr.Number(value=5.0, precision=1, label="Guidance Scale", interactive=True) clip_guidance_scale = gr.Number(value=150, precision=1, label="Clip Guidance Scale", interactive=True) cutn = gr.Number(value=16, precision=1, label="Number of Cuts", interactive=True) l2_sim_lambda = gr.Number(value=10000, precision=1, label="L2 similarity to original image") randomize_seed = gr.Radio( ["Fix Seed", "Randomize Seed"], value="Randomize Seed", type="index", show_label=False, interactive=True, ) # use_ddim = gr.Checkbox(label="Use 50-step DDIM?", value=True) # use_ddpm = gr.Checkbox(label="Use 50-step DDPM?", value=True) gr.Markdown(help_text) generate_button.click( fn=generate, inputs=[ input_image, from_text, instruction, negative_prompt, randomize_seed, seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda ], outputs=[seed, edited_image_1], ) load_button.click( fn=load_sample, inputs=[], outputs=[input_image, from_text, instruction, negative_prompt, seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda, edited_image_1], ) reset_button.click( fn=reset, inputs=[], outputs=[ randomize_seed, seed, edited_image_1, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda ], ) demo.queue(concurrency_count=1) demo.launch(share=False, server_name="0.0.0.0") if __name__ == "__main__": main()