# d-edit / app.py
# afeng's picture
# update
# 1f36469
# raw
# history blame
# No virus
# 19.2 kB
import os
import copy
from PIL import Image
import matplotlib
import numpy as np
import gradio as gr
from utils import load_mask, load_mask_edit
from utils_mask import process_mask_to_follow_priority, mask_union, visualize_mask_list_clean
from pathlib import Path
from PIL import Image
from functools import partial
from main import run_main
# Side length of the square area used for displaying/editing images; used as
# both the height and the width of the drawing canvas in the UI below.
LENGTH=512 #length of the square area displaying/editing images
# Alpha value given to the segmentation overlay (handed to ``putalpha`` by
# transparent_paste_with_mask) when it is drawn on top of the image.
TRANSPARENCY = 150 # transparency of the mask in display
def add_mask(mask_np_list_updated, mask_label_list):
    """Append an empty mask labelled "new" to the mask and label lists.

    The new mask is all zeros and takes its shape and dtype from the first
    existing mask. Both lists are modified in place and also returned so
    they can be wired to UI outputs.
    """
    template = mask_np_list_updated[0]
    blank_mask = np.zeros_like(template)
    mask_np_list_updated.append(blank_mask)
    mask_label_list.append("new")
    return mask_np_list_updated, mask_label_list
def create_segmentation(mask_np_list):
    """Render a list of masks as one colour-coded segmentation image.

    Every mask gets its own colour taken from the 'viridis' colormap
    (resampled to ``len(mask_np_list)`` entries); the coloured layers are
    summed and converted to an 8-bit image.

    Args:
        mask_np_list: non-empty sequence of 2-D arrays of identical shape.
            Presumably 0/1 masks that do not overlap - confirm against the
            mask loaders.

    Returns:
        PIL.Image.Image built from an (H, W, 3) uint8 array.

    Raises:
        ValueError: if ``mask_np_list`` is empty.
    """
    if len(mask_np_list) == 0:
        raise ValueError("create_segmentation needs at least one mask")

    # A bare ``import matplotlib`` does not guarantee that the pyplot
    # submodule is loaded, so import what is needed explicitly.
    from matplotlib import colors as mcolors
    from matplotlib import pyplot as plt

    viridis = plt.get_cmap(name='viridis', lut=len(mask_np_list))
    segmentation = 0
    for i, m in enumerate(mask_np_list):
        color = np.asarray(mcolors.to_rgb(viridis(i)))
        # Broadcast (H, W, 1) * (3,) -> (H, W, 3): the mask painted in its colour.
        segmentation = segmentation + m[:, :, np.newaxis] * color
    # Clip before the uint8 cast so overlapping masks saturate instead of
    # wrapping around.
    segmentation = np.clip(segmentation * 255, 0, 255)
    return Image.fromarray(np.uint8(segmentation))
def load_mask_ui(input_folder="example_tmp",load_edit = False):
    """Load masks and their labels from ``input_folder`` as NumPy arrays.

    ``load_mask_edit`` is used when ``load_edit`` is true, ``load_mask``
    otherwise. Each loaded mask is moved to the CPU and converted with
    ``.numpy()``; the label list is returned as delivered by the loader.
    """
    loader = load_mask_edit if load_edit else load_mask
    loaded_masks, mask_label_list = loader(input_folder)
    mask_np_list = [mask.cpu().numpy() for mask in loaded_masks]
    return mask_np_list, mask_label_list
def load_image_ui(load_edit, input_folder="example_tmp"):
    """Load the working image and its masks for display in the UI.

    Args:
        load_edit: forwarded to ``load_mask_ui``; when true the edited masks
            are loaded instead of the original ones.
        input_folder: folder expected to contain ``img_512.png`` together
            with the saved masks.

    Returns:
        ``(image, segmentation, mask_np_list, mask_label_list, image)`` on
        success, where ``image`` is the RGB-converted picture (returned twice
        because it feeds two UI outputs). Five ``None`` values are returned
        when the image is missing or anything fails while loading, so the
        number of outputs is always the same.
    """
    failure = (None, None, None, None, None)

    img_path = Path(input_folder) / "img_512.png"
    if not img_path.is_file():
        print("Image folder invalid: The folder should contain img_512.png")
        return failure

    try:
        image = Image.open(img_path)
        mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit = load_edit)
        image = image.convert('RGB')
        segmentation = create_segmentation(mask_np_list)
    except Exception as exc:
        # UI boundary: report the problem and hand back empty outputs rather
        # than crashing the event handler.
        print("Failed to load image/masks from {}: {!r}".format(input_folder, exc))
        return failure

    print("!!", len(mask_np_list))
    return image, segmentation, mask_np_list, mask_label_list, image
# def run_edit_text(
# num_tokens,
# num_sampling_steps,
# strength,
# edge_thickness,
# tgt_prompt,
# tgt_idx,
# guidance_scale,
# input_folder="example_tmp"
# ):
# subprocess.run(["python",
# "main.py" ,
# "--text=True",
# "--name={}".format(input_folder),
# "--dpm={}".format("sd"),
# "--resolution={}".format(512),
# "--load_trained",
# "--num_tokens={}".format(num_tokens),
# "--seed={}".format(2024),
# "--guidance_scale={}".format(guidance_scale),
# "--num_sampling_step={}".format(num_sampling_steps),
# "--strength={}".format(strength),
# "--edge_thickness={}".format(edge_thickness),
# "--num_imgs={}".format(2),
# "--tgt_prompt={}".format(tgt_prompt) ,
# "--tgt_index={}".format(tgt_idx)
# ])
# return Image.open(os.path.join(input_folder, "text", "out_text_0.png"))
# def run_optimization(
# num_tokens,
# embedding_learning_rate,
# max_emb_train_steps,
# diffusion_model_learning_rate,
# max_diffusion_train_steps,
# train_batch_size,
# gradient_accumulation_steps,
# input_folder = "example_tmp"
# ):
# subprocess.run(["python",
# "main.py" ,
# "--name={}".format(input_folder),
# "--dpm={}".format("sd"),
# "--resolution={}".format(512),
# "--num_tokens={}".format(num_tokens),
# "--embedding_learning_rate={}".format(embedding_learning_rate),
# "--diffusion_model_learning_rate={}".format(diffusion_model_learning_rate),
# "--max_emb_train_steps={}".format(max_emb_train_steps),
# "--max_diffusion_train_steps={}".format(max_diffusion_train_steps),
# "--train_batch_size={}".format(train_batch_size),
# "--gradient_accumulation_steps={}".format(gradient_accumulation_steps)
# ])
# return
def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
    """Overlay ``foreimg`` semi-transparently on ``backimg`` inside a mask.

    Args:
        backimg: PIL image used as the background; it is not modified.
        foreimg: PIL image pasted on top at position (0, 0); it is not
            modified.
        mask_np: 2-D array with the same height and width as ``backimg``.
            Where it is 1 the blended picture is shown, where it is 0 the
            untouched background is kept.
        transparency: alpha value applied to the foreground copy before it
            is pasted onto the background copy.

    Returns:
        PIL.Image.Image with the masked overlay applied.

    Raises:
        ValueError: if the mask's height/width differ from the image's.
    """
    backimg_solid_np = np.array(backimg)

    # Work on copies so the caller's images stay untouched.
    bimg = backimg.copy()
    fimg = foreimg.copy()
    fimg.putalpha(transparency)
    bimg.paste(fimg, (0,0), fimg)
    bimg_np = np.array(bimg)

    # Float arithmetic makes the blend independent of the mask dtype
    # (bool / int / float); the result is cast back to uint8 below because
    # Image.fromarray does not accept wider multi-channel dtypes.
    mask = np.asarray(mask_np, dtype=np.float32)[:, :, np.newaxis]
    if mask.shape[:2] != bimg_np.shape[:2]:
        raise ValueError(
            "mask shape {} does not match image shape {}".format(
                mask.shape[:2], bimg_np.shape[:2]
            )
        )

    new_img_np = bimg_np * mask + (1 - mask) * backimg_solid_np
    new_img_np = np.clip(np.rint(new_img_np), 0, 255)
    return Image.fromarray(np.uint8(new_img_np))
def show_segmentation(image, segmentation, flag):
    """Toggle between the plain image and the image with the overlay.

    Args:
        image: PIL image currently loaded.
        segmentation: PIL image holding the colour-coded segmentation.
        flag: ``False`` when the overlay is currently hidden.

    Returns:
        ``(image_to_display, new_flag)``. When ``flag`` is ``False`` the
        overlay is drawn over the whole image and the flag becomes ``True``;
        otherwise the plain image is returned and the flag becomes ``False``.
    """
    if flag is False:
        # PIL reports size as (width, height) while NumPy arrays are indexed
        # (rows, cols) = (height, width).
        width, height = image.size
        mask_np = np.ones([height, width], dtype=np.uint8)
        image_edit = transparent_paste_with_mask(image, segmentation, mask_np ,transparency = TRANSPARENCY)
        return image_edit, True
    return image, False
def edit_mask_add(canvas, image, idx, mask_np_list):
    """Merge the region drawn on the canvas into the mask at ``idx``.

    The first channel of ``canvas["mask"]`` is scaled from 0-255 to 0/1 and
    united with the selected mask via ``mask_union``. The full list is then
    passed through ``process_mask_to_follow_priority`` with priority 1 for
    the edited mask and 0 for all others (presumably so the edited mask wins
    where masks overlap - confirm in utils_mask).

    Returns:
        ``(updated mask list, image with the new segmentation overlaid)``.
    """
    mask_sel = mask_np_list[idx]
    drawn = np.uint8(canvas["mask"][:, :, 0]/ 255.)

    mask_np_list_updated = [
        mask_union(mask_sel, drawn) if position == idx else existing
        for position, existing in enumerate(mask_np_list)
    ]

    priority_list = [0] * len(mask_np_list_updated)
    priority_list[idx] = 1
    mask_np_list_updated = process_mask_to_follow_priority(mask_np_list_updated, priority_list)

    # The overlay is applied everywhere, hence an all-ones mask.
    full_mask = np.ones((mask_sel.shape[0], mask_sel.shape[1]), dtype=np.uint8)
    segmentation = create_segmentation(mask_np_list_updated)
    image_edit = transparent_paste_with_mask(image, segmentation, full_mask ,transparency = TRANSPARENCY)
    return mask_np_list_updated, image_edit
def slider_release(index, image, mask_np_list_updated, mask_label_list):
    """Highlight the mask selected by the slider and report its label.

    Args:
        index: position chosen on the slider.
        image: PIL image currently loaded.
        mask_np_list_updated: current list of masks.
        mask_label_list: labels matching ``mask_np_list_updated``.

    Returns:
        ``(image, "out of range")`` when no mask/label exists at ``index``;
        otherwise ``(new_image, mask_label)`` where ``new_image`` shows the
        segmentation overlay only inside the selected mask.
    """
    index = int(index)
    # Valid positions are 0 .. len-1, so ``index == len`` is already out of
    # range; the label list is checked as well since it is indexed too.
    if index >= min(len(mask_np_list_updated), len(mask_label_list)):
        return image, "out of range"

    mask_np = mask_np_list_updated[index]
    mask_label = mask_label_list[index]
    segmentation = create_segmentation(mask_np_list_updated)
    new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency = TRANSPARENCY)
    return new_image, mask_label
def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Validate the masks and save them as the original masks.

    Each mask is written to ``mask{index}_{label}.npy`` inside
    ``input_folder`` and a visualisation is written to ``seg_current.png``
    via ``visualize_mask_list_clean``.

    Raises:
        ValueError: if the list is empty or the masks do not sum to exactly
            1 at every pixel (i.e. pixels are uncovered or covered twice).
            Nothing is written in that case.
    """
    # Explicit validation instead of assert + debugger: an interactive
    # breakpoint would stall the web app, and asserts vanish under -O.
    if len(mask_np_list_updated) == 0 or not np.all(sum(mask_np_list_updated) == 1):
        print("please check mask")
        raise ValueError("please check mask: masks must sum to 1 at every pixel")

    for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
        np.save(os.path.join(input_folder, "mask{}_{}.npy".format(midx, mask_label)),mask )
    savepath = os.path.join(input_folder, "seg_current.png")
    visualize_mask_list_clean(mask_np_list_updated, savepath)
def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Validate the masks and save them as the edited masks.

    Each mask is written to ``maskEdited{index}_{label}.npy`` inside
    ``input_folder`` and a visualisation is written to ``seg_edited.png``
    via ``visualize_mask_list_clean``.

    Raises:
        ValueError: if the list is empty or the masks do not sum to exactly
            1 at every pixel (i.e. pixels are uncovered or covered twice).
            Nothing is written in that case.
    """
    # Explicit validation instead of assert + debugger: an interactive
    # breakpoint would stall the web app, and asserts vanish under -O.
    if len(mask_np_list_updated) == 0 or not np.all(sum(mask_np_list_updated) == 1):
        print("please check mask")
        raise ValueError("please check mask: masks must sum to 1 at every pixel")

    for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
        np.save(os.path.join(input_folder, "maskEdited{}_{}.npy".format(midx, mask_label)), mask)
    savepath = os.path.join(input_folder, "seg_edited.png")
    visualize_mask_list_clean(mask_np_list_updated, savepath)
import shutil
# Remove the ./example_tmp directory (the default ``input_folder`` of the
# functions above) if it exists when this module is executed.
if os.path.isdir("./example_tmp"):
    shutil.rmtree("./example_tmp")
# NOTE(review): this import sits after the cleanup above - presumably the
# ordering is deliberate; confirm before moving it to the top of the file.
from segment import run_segmentation
# Gradio UI with three tabs: (1) segmentation and mask editing,
# (2) optimization, (3) text-based editing. Event handlers are the functions
# defined above plus a few small wrappers defined inline.
with gr.Blocks() as demo:
    # ---- state holders ----
    # NOTE(review): ``image`` is not used by any handler visible in this file.
    image = gr.State()
    image_loaded = gr.State()   # image returned by load_image_ui
    segmentation = gr.State()   # colour-coded segmentation image

    mask_np_list = gr.State([])
    mask_label_list = gr.State([])
    mask_np_list_updated = gr.State([])
    # Constant True/False values that can be passed as event inputs.
    true = gr.State(True)
    false = gr.State(False)
    # Written by run_segmentation; a change reveals the step-1.2 controls.
    block_flag = gr.State(0)
    num_tokens_global = gr.State(5)

    with gr.Row():
        gr.Markdown("""# D-Edit""")

    with gr.Tab(label="1 Edit mask"):
        with gr.Row():
            with gr.Column():
                canvas = gr.Image(value = "./img.png", type="numpy", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)

                segment_button = gr.Button("1.1 Run segmentation")
                segment_button.click(run_segmentation,
                    [canvas, block_flag] ,
                    [block_flag] )

                # The next three controls start with a placeholder label and
                # are relabelled by show_more_buttons once block_flag changes.
                text_button = gr.Button("Waiting 1.1 to complete")
                text_button.click(load_image_ui,
                    [ false] ,
                    [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )

                load_edit_button = gr.Button("Waiting 1.1 to complete")
                load_edit_button.click(load_image_ui,
                    [ true] ,
                    [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )

                show_segment = gr.Checkbox(label = "Waiting 1.1 to complete")
                # Tracks whether the segmentation overlay is currently shown.
                flag = gr.State(False)
                show_segment.select(show_segmentation,
                    [image_loaded, segmentation, flag],
                    [canvas, flag])

                def show_more_buttons():
                    """Return the relabelled step-1.2 buttons and checkbox."""
                    return gr.Button("1.2 Load original masks"), gr.Button("1.2 Load edited masks") , gr.Checkbox(label = "Show Segmentation")

                block_flag.change(show_more_buttons, [], [text_button,load_edit_button,show_segment ])

                # mask_np_list_updated.value = copy.deepcopy(mask_np_list.value) #!!
                # From here on the name refers to the SAME State object as
                # mask_np_list, so the handlers below read and write the
                # loaded masks directly; the separate State created at the
                # top is no longer referenced.
                mask_np_list_updated = mask_np_list

            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Edit Mask (Optional)</p>""")
                slider = gr.Slider(0, 20, step=1, interactive=True)
                label = gr.Textbox()
                slider.release(slider_release,
                    inputs = [slider, image_loaded, mask_np_list_updated, mask_label_list],
                    outputs= [canvas, label]
                )

                add_button = gr.Button("Add")
                add_button.click( edit_mask_add,
                    [canvas, image_loaded, slider, mask_np_list_updated] ,
                    [mask_np_list_updated, canvas]
                )

                save_button2 = gr.Button("Set and Save as edited masks")
                save_button2.click( save_as_edit_mask,
                    [mask_np_list_updated, mask_label_list] ,
                    [] )

                save_button = gr.Button("Set and Save as original masks")
                save_button.click( save_as_orig_mask,
                    [mask_np_list_updated, mask_label_list] ,
                    [] )

                # Called without inputs, so load_mask_ui runs with its
                # default folder and loads the original masks.
                back_button = gr.Button("Back to current seg")
                back_button.click( load_mask_ui,
                    [] ,
                    [ mask_np_list_updated,mask_label_list] )

                add_mask_button = gr.Button("Add new empty mask")
                add_mask_button.click(add_mask,
                    [mask_np_list_updated, mask_label_list] ,
                    [mask_np_list_updated, mask_label_list] )

    with gr.Tab(label="2 Optimization"):
        with gr.Row():
            with gr.Column():
                txt_box = gr.Textbox("Click to start optimization...", interactive = False)
                # Incremented when an optimization run returns; its change
                # event updates the status text.
                opt_flag = gr.State(0)
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
                num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
                # Rebinds the name to the Number component above, so the
                # editing tab receives the value the user entered here.
                num_tokens_global = num_tokens
                embedding_learning_rate = gr.Textbox(value="0.0001", label="Embedding optimization: Learning rate", interactive= True )
                max_emb_train_steps = gr.Number(value="200", label="embedding optimization: Training steps", interactive= True )

                diffusion_model_learning_rate = gr.Textbox(value="0.00005", label="UNet Optimization: Learning rate", interactive= True )
                max_diffusion_train_steps = gr.Number(value="200", label="UNet Optimization: Learning rate: Training steps", interactive= True )

                train_batch_size = gr.Number(value="5", label="Batch size", interactive= True )
                gradient_accumulation_steps=gr.Number(value="5", label="Gradient accumulation", interactive= True )

                add_button = gr.Button("Run optimization")

                def run_optimization_wrapper (
                    opt_flag,
                    num_tokens,
                    embedding_learning_rate ,
                    max_emb_train_steps ,
                    diffusion_model_learning_rate ,
                    max_diffusion_train_steps,
                    train_batch_size,
                    gradient_accumulation_steps
                ):
                    """Run ``run_main`` with the optimization settings from the UI.

                    UI values are converted to int/float before the call.
                    Returns ``opt_flag + 1`` so that the flag's change event
                    fires once the run has returned.
                    """
                    run_optimization = partial(
                        run_main,
                        num_tokens=int(num_tokens),
                        embedding_learning_rate = float(embedding_learning_rate),
                        max_emb_train_steps = int(max_emb_train_steps),
                        diffusion_model_learning_rate= float(diffusion_model_learning_rate),
                        max_diffusion_train_steps = int(max_diffusion_train_steps),
                        train_batch_size=int(train_batch_size),
                        gradient_accumulation_steps=int(gradient_accumulation_steps)
                    )
                    run_optimization()
                    return opt_flag+1

                add_button.click(run_optimization_wrapper,
                    inputs = [
                        opt_flag,
                        num_tokens,
                        embedding_learning_rate ,
                        max_emb_train_steps ,
                        diffusion_model_learning_rate ,
                        max_diffusion_train_steps,
                        train_batch_size,
                        gradient_accumulation_steps
                    ],
                    outputs = [opt_flag]
                )

                def change_text(txt_box):
                    """Return the status textbox shown after optimization."""
                    return gr.Textbox("Optimization Finished!", interactive = False)

                def change_text2(txt_box):
                    """Return the status textbox shown when optimization starts."""
                    return gr.Textbox("Start optimization, check logs for progress...", interactive = False)

                add_button.click(change_text2, txt_box, txt_box)
                opt_flag.change(change_text, txt_box, txt_box)

    with gr.Tab(label="3 Editing"):
        with gr.Tab(label="3.1 Text-based editing"):
            with gr.Row():
                with gr.Column():
                    canvas_text_edit = gr.Image(value = None, type = "pil", label="Editing results", show_label=True)
                    # canvas_text_edit = gr.Gallery(label = "Edited results")

                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")

                    tgt_prompt = gr.Textbox(value="White bag", label="Editing: Text prompt", interactive= True )
                    tgt_index = gr.Number(value="0", label="Editing: Object index", interactive= True )
                    guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive= True )
                    num_sampling_steps = gr.Number(value="50", label="Editing: Sampling steps", interactive= True )
                    edge_thickness = gr.Number(value="10", label="Editing: Edge thickness", interactive= True )
                    strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )

                    add_button = gr.Button("Run Editing")

                    def run_edit_text_wrapper(
                        num_tokens,
                        guidance_scale,
                        num_sampling_steps ,
                        strength ,
                        edge_thickness,
                        tgt_prompt ,
                        tgt_index
                    ):
                        """Run ``run_main`` in text-editing mode with the UI settings.

                        Returns whatever ``run_main`` returns; the result is
                        sent to the "Editing results" image.
                        """
                        run_edit_text = partial(
                            run_main,
                            load_trained=True,
                            text=True,
                            # Use the value delivered with the event; a
                            # component's ``.value`` attribute only holds its
                            # initial value, not what the user entered.
                            num_tokens = int(num_tokens),
                            guidance_scale = float(guidance_scale),
                            num_sampling_steps = int(num_sampling_steps),
                            strength = float(strength),
                            edge_thickness = int(edge_thickness),
                            num_imgs = 1,
                            tgt_prompt = tgt_prompt,
                            tgt_index = int(tgt_index)
                        )
                        return run_edit_text()

                    add_button.click(run_edit_text_wrapper,
                        inputs = [num_tokens_global,
                                  guidance_scale,
                                  num_sampling_steps,
                                  strength ,
                                  edge_thickness,
                                  tgt_prompt ,
                                  tgt_index
                        ],
                        outputs = [canvas_text_edit]
                    )

                    def load_pil_img():
                        """Open the first text-editing result image from disk."""
                        from PIL import Image
                        return Image.open("example_tmp/text/out_text_0.png")

                    load_button = gr.Button("Load results")
                    load_button.click(load_pil_img,
                        inputs = [],
                        outputs = [canvas_text_edit]
                    )

demo.queue().launch(share=True, debug=True)