kai-2054 committed
Commit 81ec919 · verified · 1 Parent(s): ac2ca83

Create app.py

Files changed (1)
  1. app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
+ from __future__ import annotations
+ import os
+ os.system("pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers")
+ os.system("pip install -e git+https://github.com/alvanli/RDM-Region-Aware-Diffusion-Model.git@main#egg=guided_diffusion")
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
+
+ import math
+ import random
+
+ import gradio as gr
+ import torch
+ from PIL import Image, ImageOps
+ from run_edit import run_model
+ from cool_models import make_models
+
+ help_text = """"""
+
+
+ def main():
+     segmodel, model, diffusion, ldm, bert, clip_model, model_params = make_models()
+
+     def load_sample():
+         SAMPLE_IMAGE = "./flower1.jpg"
+         input_image = Image.open(SAMPLE_IMAGE)
+         from_text = "a flower"
+         instruction = "a sunflower"
+         negative_prompt = ""
+         seed = 42
+         guidance_scale = 5.0
+         clip_guidance_scale = 150
+         cutn = 16
+         l2_sim_lambda = 10_000
+
+         edited_image_1 = run_model(
+             segmodel, model, diffusion, ldm, bert, clip_model, model_params,
+             from_text, instruction, negative_prompt, input_image.convert('RGB'), seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda
+         )
+
+         return [
+             input_image, from_text, instruction, negative_prompt, seed, guidance_scale,
+             clip_guidance_scale, cutn, l2_sim_lambda, edited_image_1
+         ]
+
+
+     def generate(
+         input_image: Image.Image,
+         from_text: str,
+         instruction: str,
+         negative_prompt: str,
+         randomize_seed: bool,
+         seed: int,
+         guidance_scale: float,
+         clip_guidance_scale: float,
+         cutn: int,
+         l2_sim_lambda: float
+     ):
+         seed = random.randint(0, 100000) if randomize_seed else seed
+
+         if instruction == "":
+             return [seed, input_image]
+
+         generator = torch.manual_seed(seed)
+
+         edited_image_1 = run_model(
+             segmodel, model, diffusion, ldm, bert, clip_model, model_params,
+             from_text, instruction, negative_prompt, input_image.convert('RGB'), seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda
+         )
+
+         return [seed, edited_image_1]
+
+     def reset():
+         return [
+             "Randomize Seed", 42, None, 5.0,
+             150, 16, 10000
+         ]
+
+     with gr.Blocks() as demo:
+         gr.Markdown("""
+ #### RDM: Region-Aware Diffusion for Zero-shot Text-driven Image Editing
+ Original Github Repo: https://github.com/haha-lisa/RDM-Region-Aware-Diffusion-Model <br/>
+ Instructions: <br/>
+ - In the "From Text" field, specify the object you are trying to modify,
+ - In the "edit instruction" field, specify what you want that area to be turned into
+         """)
+         with gr.Row():
+             with gr.Column(scale=1, min_width=100):
+                 generate_button = gr.Button("Generate")
+             with gr.Column(scale=1, min_width=100):
+                 load_button = gr.Button("Load Example")
+             with gr.Column(scale=1, min_width=100):
+                 reset_button = gr.Button("Reset")
+             with gr.Column(scale=3):
+                 from_text = gr.Textbox(lines=1, label="From Text", interactive=True)
+                 instruction = gr.Textbox(lines=1, label="Edit Instruction", interactive=True)
+                 negative_prompt = gr.Textbox(lines=1, label="Negative Prompt", interactive=True)
+
+         with gr.Row():
+             input_image = gr.Image(label="Input Image", type="pil", interactive=True)
+             edited_image_1 = gr.Image(label=f"Edited Image", type="pil", interactive=False)
+             # edited_image_2 = gr.Image(label=f"Edited Image", type="pil", interactive=False)
+             input_image.style(height=512, width=512)
+             edited_image_1.style(height=512, width=512)
+             # edited_image_2.style(height=512, width=512)
+
+         with gr.Row():
+             # steps = gr.Number(value=50, precision=0, label="Steps", interactive=True)
+             seed = gr.Number(value=1371, precision=0, label="Seed", interactive=True)
+             guidance_scale = gr.Number(value=5.0, precision=1, label="Guidance Scale", interactive=True)
+             clip_guidance_scale = gr.Number(value=150, precision=1, label="Clip Guidance Scale", interactive=True)
+             cutn = gr.Number(value=16, precision=1, label="Number of Cuts", interactive=True)
+             l2_sim_lambda = gr.Number(value=10000, precision=1, label="L2 similarity to original image")
+
+         randomize_seed = gr.Radio(
+             ["Fix Seed", "Randomize Seed"],
+             value="Randomize Seed",
+             type="index",
+             show_label=False,
+             interactive=True,
+         )
+         # use_ddim = gr.Checkbox(label="Use 50-step DDIM?", value=True)
+         # use_ddpm = gr.Checkbox(label="Use 50-step DDPM?", value=True)
+
+         gr.Markdown(help_text)
+
+         generate_button.click(
+             fn=generate,
+             inputs=[
+                 input_image, from_text, instruction, negative_prompt, randomize_seed,
+                 seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda
+             ],
+             outputs=[seed, edited_image_1],
+         )
+
+         load_button.click(
+             fn=load_sample,
+             inputs=[],
+             outputs=[input_image, from_text, instruction, negative_prompt, seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda, edited_image_1],
+         )
+
+
+         reset_button.click(
+             fn=reset,
+             inputs=[],
+             outputs=[
+                 randomize_seed, seed, edited_image_1, guidance_scale,
+                 clip_guidance_scale, cutn, l2_sim_lambda
+             ],
+         )
+
+     demo.queue(concurrency_count=1)
+     demo.launch(share=False, server_name="0.0.0.0")
+
+
+ if __name__ == "__main__":
+     main()
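
For reference, the pipeline that this app.py wires into Gradio can also be driven without the UI. The snippet below is a hedged sketch and is not part of the commit: it reuses make_models() and run_model() with the exact argument order shown in the diff above, assumes flower1.jpg sits in the working directory, and assumes run_model returns a PIL image (app.py feeds its result straight to a gr.Image output with type="pil").

# headless_edit.py: minimal sketch of one edit without Gradio (assumptions noted above)
from PIL import Image

from run_edit import run_model
from cool_models import make_models

def edit_once():
    # Load all model components once, as main() does in app.py.
    segmodel, model, diffusion, ldm, bert, clip_model, model_params = make_models()

    input_image = Image.open("./flower1.jpg").convert("RGB")

    # Argument order copied from app.py's run_model calls.
    edited_image = run_model(
        segmodel, model, diffusion, ldm, bert, clip_model, model_params,
        "a flower",     # from_text: region to modify
        "a sunflower",  # instruction: what the region should become
        "",             # negative_prompt
        input_image,
        42,             # seed
        5.0,            # guidance_scale
        150,            # clip_guidance_scale
        16,             # cutn
        10_000,         # l2_sim_lambda
    )
    # Assumed to be a PIL image, since app.py displays it directly.
    edited_image.save("edited.png")

if __name__ == "__main__":
    edit_once()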