from __future__ import annotations
import os
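# The demo depends on two packages that are not on PyPI, so they are cloned and
# installed at import time (a common workaround for hosted demos). A minimal
# sketch of a sturdier variant, assuming `sys` and `subprocess` are imported,
# would call pip through the running interpreter instead:
#   subprocess.check_call([sys.executable, "-m", "pip", "install", "-e",
#       "git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers"])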
os.system("pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers")
os.system("pip install -e git+https://github.com/alvanli/RDM-Region-Aware-Diffusion-Model.git@main#egg=guided_diffusion")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
import math
import random
import gradio as gr
import torch
from PIL import Image, ImageOps
from run_edit import run_model
from cool_models import make_models
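# run_edit and cool_models ship with the RDM repo installed above.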
help_text = ""  # rendered via gr.Markdown below; left empty for now
def main():
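    # Build every model once at startup; the nested callbacks close over these,
    # so all Gradio requests share a single copy of the weights.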
segmodel, model, diffusion, ldm, bert, clip_model, model_params = make_models()
def load_sample():
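        """Run the bundled example ("a flower" -> "a sunflower") and return values for every UI component."""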
SAMPLE_IMAGE = "./flower1.jpg"
input_image = Image.open(SAMPLE_IMAGE)
from_text = "a flower"
instruction = "a sunflower"
negative_prompt = ""
seed = 42
guidance_scale = 5.0
clip_guidance_scale = 150
cutn = 16
l2_sim_lambda = 10_000
edited_image_1 = run_model(
segmodel, model, diffusion, ldm, bert, clip_model, model_params,
from_text, instruction, negative_prompt, input_image.convert('RGB'), seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda
)
return [
input_image, from_text, instruction, negative_prompt, seed, guidance_scale,
clip_guidance_scale, cutn, l2_sim_lambda, edited_image_1
]
def generate(
input_image: Image.Image,
from_text: str,
instruction: str,
negative_prompt: str,
randomize_seed: bool,
seed: int,
guidance_scale: float,
clip_guidance_scale: float,
cutn: int,
l2_sim_lambda: float
):
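        """Gradio callback: pick a seed, run the edit, and return [seed, edited image]."""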
seed = random.randint(0, 100000) if randomize_seed else seed
if instruction == "":
return [seed, input_image]
        torch.manual_seed(seed)  # seed torch's global RNG; run_model also receives `seed` explicitly
edited_image_1 = run_model(
segmodel, model, diffusion, ldm, bert, clip_model, model_params,
from_text, instruction, negative_prompt, input_image.convert('RGB'), seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda
)
return [seed, edited_image_1]
def reset():
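        """Restore default widget values; order must match reset_button's outputs list."""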
        return [
            "Randomize Seed",  # randomize_seed radio
            42,                # seed
            None,              # edited_image_1 (cleared)
            5.0,               # guidance_scale
            150,               # clip_guidance_scale
            16,                # cutn
            10000,             # l2_sim_lambda
        ]
with gr.Blocks() as demo:
gr.HTML("""
RDM: Region-Aware Diffusion for Zero-shot Text-driven Image Editing
In the "From Text" field, specify the object you are trying to modify, in the "edit instruction" field, specify what you want that area to be turned into
""")
with gr.Row():
with gr.Column(scale=1, min_width=100):
generate_button = gr.Button("Generate")
with gr.Column(scale=1, min_width=100):
load_button = gr.Button("Load Example")
with gr.Column(scale=1, min_width=100):
reset_button = gr.Button("Reset")
with gr.Column(scale=3):
from_text = gr.Textbox(lines=1, label="From Text", interactive=True)
instruction = gr.Textbox(lines=1, label="Edit Instruction", interactive=True)
negative_prompt = gr.Textbox(lines=1, label="Negative Prompt", interactive=True)
with gr.Row():
input_image = gr.Image(label="Input Image", type="pil", interactive=True)
edited_image_1 = gr.Image(label=f"Edited Image", type="pil", interactive=False)
# edited_image_2 = gr.Image(label=f"Edited Image", type="pil", interactive=False)
input_image.style(height=512, width=512)
edited_image_1.style(height=512, width=512)
# edited_image_2.style(height=512, width=512)
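        # Sampler hyper-parameters, passed straight through to run_model.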
with gr.Row():
# steps = gr.Number(value=50, precision=0, label="Steps", interactive=True)
seed = gr.Number(value=1371, precision=0, label="Seed", interactive=True)
guidance_scale = gr.Number(value=5.0, precision=1, label="Guidance Scale", interactive=True)
            clip_guidance_scale = gr.Number(value=150, precision=1, label="CLIP Guidance Scale", interactive=True)
            cutn = gr.Number(value=16, precision=0, label="Number of Cuts", interactive=True)
            l2_sim_lambda = gr.Number(value=10000, precision=1, label="L2 Similarity to Original Image", interactive=True)
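            # type="index" makes the radio return the selected index (0 or 1),
            # so "Randomize Seed" (index 1) is truthy inside generate().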
randomize_seed = gr.Radio(
["Fix Seed", "Randomize Seed"],
value="Randomize Seed",
type="index",
show_label=False,
interactive=True,
)
# use_ddim = gr.Checkbox(label="Use 50-step DDIM?", value=True)
# use_ddpm = gr.Checkbox(label="Use 50-step DDPM?", value=True)
gr.Markdown(help_text)
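        # Event wiring: each callback's return values must line up one-to-one with its `outputs`.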
generate_button.click(
fn=generate,
inputs=[
input_image, from_text, instruction, negative_prompt, randomize_seed,
seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda
],
outputs=[seed, edited_image_1],
)
load_button.click(
fn=load_sample,
inputs=[],
outputs=[input_image, from_text, instruction, negative_prompt, seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda, edited_image_1],
)
reset_button.click(
fn=reset,
inputs=[],
outputs=[
randomize_seed, seed, edited_image_1, guidance_scale,
clip_guidance_scale, cutn, l2_sim_lambda
],
)
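    # Process one request at a time so concurrent edits don't contend for the models,
    # and bind to all interfaces so the demo is reachable from outside the host.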
demo.queue(concurrency_count=1)
demo.launch(share=False, server_name="0.0.0.0")
if __name__ == "__main__":
main()