import os
import copy
from functools import partial
from pathlib import Path

from PIL import Image
import matplotlib
import numpy as np
import gradio as gr

from utils import load_mask, load_mask_edit
from utils_mask import process_mask_to_follow_priority, mask_union, visualize_mask_list_clean
from main import run_main

LENGTH = 512  # length of the square area displaying/editing images
TRANSPARENCY = 150  # transparency of the mask in display

# Image file that load_image_ui() expects inside the working folder.
_IMAGE_FILENAME = "img_512.png"


def add_mask(mask_np_list_updated, mask_label_list):
    """Append an all-zero mask labelled ``"new"`` to the editable mask list.

    The new mask copies the shape and dtype of the first existing mask, so at
    least one mask must already be loaded.

    Args:
        mask_np_list_updated: list of mask arrays; appended to in place.
        mask_label_list: list of labels parallel to the masks; appended to in place.

    Returns:
        The same two (mutated) lists, so Gradio writes them back to session state.

    Raises:
        gr.Error: if no mask has been loaded yet (shown to the user in the UI).
    """
    if not mask_np_list_updated:
        raise gr.Error("No masks loaded yet: load masks before adding a new empty mask")
    mask_np_list_updated.append(np.zeros_like(mask_np_list_updated[0]))
    mask_label_list.append("new")
    return mask_np_list_updated, mask_label_list


def create_segmentation(mask_np_list):
    """Render masks as one RGB image, giving mask ``i`` the i-th viridis colour.

    Args:
        mask_np_list: non-empty sequence of 2-D masks of identical shape. They are
            assumed to be 0/1 (bool or numeric) and non-overlapping; where masks
            do overlap, the summed colour is clipped to 1.0 instead of wrapping
            around when cast to uint8.

    Returns:
        A PIL RGB image with the same height/width as the masks.

    Raises:
        ValueError: if ``mask_np_list`` is empty.
    """
    # `import matplotlib` alone does not guarantee `matplotlib.pyplot` is loaded;
    # import it explicitly (locally, so the module-level import order is unchanged).
    import matplotlib.pyplot as plt

    if len(mask_np_list) == 0:
        raise ValueError("create_segmentation() needs at least one mask")
    viridis = plt.get_cmap(name="viridis", lut=len(mask_np_list))
    segmentation = np.zeros(np.shape(mask_np_list[0]) + (3,), dtype=np.float64)
    for i, mask in enumerate(mask_np_list):
        color = np.asarray(matplotlib.colors.to_rgb(viridis(i)))
        segmentation += np.asarray(mask, dtype=np.float64)[:, :, np.newaxis] * color
    segmentation = np.clip(segmentation, 0.0, 1.0)
    return Image.fromarray(np.uint8(segmentation * 255))


def load_mask_ui(input_folder="example_tmp", load_edit=False):
    """Load the original (or, with ``load_edit``, the edited) masks as numpy arrays.

    Args:
        input_folder: folder passed to ``load_mask`` / ``load_mask_edit``.
        load_edit: use ``load_mask_edit`` instead of ``load_mask``.

    Returns:
        ``(mask_np_list, mask_label_list)``. Each loaded mask is converted with
        ``.cpu().numpy()`` (presumably torch tensors — confirm in ``utils``).
    """
    loader = load_mask_edit if load_edit else load_mask
    mask_list, mask_label_list = loader(input_folder)
    mask_np_list = [m.cpu().numpy() for m in mask_list]
    return mask_np_list, mask_label_list


def load_image_ui(load_edit, input_folder="example_tmp"):
    """Load ``img_512.png`` and its masks from ``input_folder`` for the UI.

    Args:
        load_edit: forwarded to :func:`load_mask_ui` (edited vs. original masks).
        input_folder: folder holding the image and mask files.

    Returns:
        ``(image, segmentation, mask_np_list, mask_label_list, image)``, matching
        the five Gradio outputs. On any failure the error is printed and five
        ``None`` values are returned, so the UI stays responsive.
    """
    img_path = Path(input_folder) / _IMAGE_FILENAME
    if not img_path.is_file():
        print("Image folder invalid: The folder should contain {}".format(_IMAGE_FILENAME))
        return None, None, None, None, None
    try:
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit=load_edit)
        segmentation = create_segmentation(mask_np_list)
    except Exception as err:  # UI boundary: report the real cause, keep the app alive
        print("Failed to load image/masks from {}: {}".format(input_folder, err))
        return None, None, None, None, None
    print("!!", len(mask_np_list))
    return image, segmentation, mask_np_list, mask_label_list, image


# NOTE: legacy subprocess-based launchers, superseded by calling run_main()
# directly from the UI below; kept only for reference.
# def run_edit_text(
#     num_tokens,
#     num_sampling_steps,
#     strength,
#     edge_thickness,
#     tgt_prompt,
#     tgt_idx,
#     guidance_scale,
#     input_folder="example_tmp"
# ):
#     subprocess.run(["python",
#         "main.py" ,
#         "--text=True",
#         "--name={}".format(input_folder),
#         "--dpm={}".format("sd"),
#         "--resolution={}".format(512),
#         "--load_trained",
#         "--num_tokens={}".format(num_tokens),
#         "--seed={}".format(2024),
#         "--guidance_scale={}".format(guidance_scale),
#         "--num_sampling_step={}".format(num_sampling_steps),
#         "--strength={}".format(strength),
#         "--edge_thickness={}".format(edge_thickness),
#         "--num_imgs={}".format(2),
#         "--tgt_prompt={}".format(tgt_prompt) ,
#         "--tgt_index={}".format(tgt_idx)
#     ])
#     return Image.open(os.path.join(input_folder, "text", "out_text_0.png"))

# def run_optimization(
#     num_tokens,
#     embedding_learning_rate,
#     max_emb_train_steps,
#     diffusion_model_learning_rate,
#     max_diffusion_train_steps,
#     train_batch_size,
#     gradient_accumulation_steps,
#     input_folder = "example_tmp"
# ):
#     subprocess.run(["python",
#         "main.py" ,
#         "--name={}".format(input_folder),
#         "--dpm={}".format("sd"),
#         "--resolution={}".format(512),
#         "--num_tokens={}".format(num_tokens),
#         "--embedding_learning_rate={}".format(embedding_learning_rate),
#         "--diffusion_model_learning_rate={}".format(diffusion_model_learning_rate),
#         "--max_emb_train_steps={}".format(max_emb_train_steps),
#         "--max_diffusion_train_steps={}".format(max_diffusion_train_steps),
#         "--train_batch_size={}".format(train_batch_size),
#         "--gradient_accumulation_steps={}".format(gradient_accumulation_steps)
#     ])
#     return


def transparent_paste_with_mask(backimg, foreimg, mask_np, transparency=128):
    """Blend ``foreimg`` semi-transparently over ``backimg`` where ``mask_np`` is set.

    Args:
        backimg: background PIL image (RGB in this app).
        foreimg: foreground PIL image of the same size. It is copied and given a
            uniform alpha of ``transparency`` before being pasted.
        mask_np: 2-D (height, width) weight map. 1 shows the blended pixel and 0
            keeps the background pixel. bool, integer and float masks are accepted.
        transparency: alpha (0-255) applied to the foreground.

    Returns:
        A new PIL image; neither input image is modified.
    """
    background_np = np.asarray(backimg, dtype=np.float32)
    blended = backimg.copy()
    overlay = foreimg.copy()
    overlay.putalpha(transparency)
    blended.paste(overlay, (0, 0), overlay)
    blended_np = np.asarray(blended, dtype=np.float32)
    # Blend in float so bool/float masks work (``1 - bool_array`` raises in numpy)
    # and cast back to uint8, which is what Image.fromarray needs for RGB data.
    weight = np.asarray(mask_np, dtype=np.float32)[:, :, np.newaxis]
    new_img_np = blended_np * weight + (1.0 - weight) * background_np
    return Image.fromarray(np.clip(np.rint(new_img_np), 0, 255).astype(np.uint8))


def show_segmentation(image, segmentation, flag):
    """Toggle the full segmentation overlay on the canvas.

    Args:
        image: the loaded PIL image.
        segmentation: PIL image from create_segmentation().
        flag: ``False`` when the overlay is currently hidden.

    Returns:
        ``(canvas_image, new_flag)``: the overlaid image and ``True`` when turning
        the overlay on, otherwise the plain image and ``False``.
    """
    if flag is not False:
        return image, False
    if image is None or segmentation is None:
        raise gr.Error("Load the image and masks first (step 1.2)")
    # PIL reports (width, height); numpy masks are (rows, cols) = (height, width).
    width, height = image.size
    mask_np = np.ones((height, width), dtype=np.uint8)
    image_edit = transparent_paste_with_mask(image, segmentation, mask_np, transparency=TRANSPARENCY)
    return image_edit, True


def edit_mask_add(canvas, image, idx, mask_np_list):
    """Add the pixels drawn on the canvas to mask ``idx``.

    The drawn region is merged into mask ``idx`` with ``mask_union``. The list
    is then passed to ``process_mask_to_follow_priority`` with priority 1 for
    ``idx`` and 0 for the others. Presumably this removes the drawn pixels from
    the other masks; confirm in utils_mask.

    Args:
        canvas: canvas value. Assumed to be a dict whose ``"mask"`` entry is an
            (H, W, C) array with 0/255 strokes (Gradio sketch tool). TODO confirm
            against the installed Gradio version.
        image: the loaded PIL image used as the display background.
        idx: index of the mask to extend (slider value; may arrive as float).
        mask_np_list: current list of masks; not modified in place.

    Returns:
        ``(updated_mask_list, canvas_image)``.

    Raises:
        gr.Error: if ``idx`` does not name an existing mask.
    """
    idx = int(idx)
    if not 0 <= idx < len(mask_np_list):
        raise gr.Error("Mask index {} is out of range ({} masks loaded)".format(idx, len(mask_np_list)))
    mask_sel = mask_np_list[idx]
    # Dividing by 255 and truncating keeps only fully painted (255) pixels.
    mask_new = np.uint8(canvas["mask"][:, :, 0] / 255.)
    mask_np_list_updated = list(mask_np_list)
    mask_np_list_updated[idx] = mask_union(mask_sel, mask_new)
    priority_list = [0] * len(mask_np_list_updated)
    priority_list[idx] = 1
    mask_np_list_updated = process_mask_to_follow_priority(mask_np_list_updated, priority_list)
    height, width = mask_sel.shape[0], mask_sel.shape[1]
    mask_ones = np.ones((height, width), dtype=np.uint8)
    segmentation = create_segmentation(mask_np_list_updated)
    image_edit = transparent_paste_with_mask(image, segmentation, mask_ones, transparency=TRANSPARENCY)
    return mask_np_list_updated, image_edit


def slider_release(index, image, mask_np_list_updated, mask_label_list):
    """Highlight mask ``index``: overlay the segmentation only inside that mask.

    Returns:
        ``(canvas_image, label_text)``. If ``index`` is not a valid mask index,
        returns the unchanged image and ``"out of range"``.
    """
    index = int(index)
    if not 0 <= index < len(mask_np_list_updated):
        return image, "out of range"
    mask_np = mask_np_list_updated[index]
    mask_label = mask_label_list[index]
    segmentation = create_segmentation(mask_np_list_updated)
    new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency=TRANSPARENCY)
    return new_image, mask_label


def _save_mask_list(mask_np_list, mask_label_list, input_folder, file_prefix, seg_filename):
    """Validate the masks and save them as ``<prefix><idx>_<label>.npy``, plus a preview PNG.

    Raises:
        gr.Error: if the masks do not cover every pixel exactly once. Nothing is
            written in that case.
    """
    if not np.all(sum(mask_np_list) == 1):
        print("please check mask")
        raise gr.Error("please check mask: every pixel must belong to exactly one mask")
    for midx, (mask, mask_label) in enumerate(zip(mask_np_list, mask_label_list)):
        np.save(os.path.join(input_folder, "{}{}_{}.npy".format(file_prefix, midx, mask_label)), mask)
    savepath = os.path.join(input_folder, seg_filename)
    visualize_mask_list_clean(mask_np_list, savepath)


def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Save the masks as the original set (``mask*.npy`` and ``seg_current.png``)."""
    _save_mask_list(mask_np_list_updated, mask_label_list, input_folder, "mask", "seg_current.png")


def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Save the masks as the edited set (``maskEdited*.npy`` and ``seg_edited.png``)."""
    _save_mask_list(mask_np_list_updated, mask_label_list, input_folder, "maskEdited", "seg_edited.png")


import shutil

# Start from an empty scratch folder every time the app is launched.
if os.path.isdir("./example_tmp"):
    shutil.rmtree("./example_tmp")

from segment import run_segmentation

with gr.Blocks() as demo:
    image = gr.State()  # store mask
    image_loaded = gr.State()
    segmentation = gr.State()
    mask_np_list = gr.State([])
    mask_label_list = gr.State([])
    mask_np_list_updated = gr.State([])
    true = gr.State(True)
    false = gr.State(False)
    block_flag = gr.State(0)
    num_tokens_global = gr.State(5)

    with gr.Row():
        gr.Markdown("""# D-Edit""")

    with gr.Tab(label="1 Edit mask"):
        with gr.Row():
            with gr.Column():
                canvas = gr.Image(value="./img.png", type="numpy", label="Draw Mask", show_label=True,
                                  height=LENGTH, width=LENGTH, interactive=True)

                segment_button = gr.Button("1.1 Run segmentation")
                segment_button.click(run_segmentation, [canvas, block_flag], [block_flag])

                text_button = gr.Button("Waiting 1.1 to complete")
                text_button.click(load_image_ui, [false],
                                  [image_loaded, segmentation, mask_np_list, mask_label_list, canvas])

                load_edit_button = gr.Button("Waiting 1.1 to complete")
                load_edit_button.click(load_image_ui, [true],
                                       [image_loaded, segmentation, mask_np_list, mask_label_list, canvas])

                show_segment = gr.Checkbox(label="Waiting 1.1 to complete")
                flag = gr.State(False)
                show_segment.select(show_segmentation, [image_loaded, segmentation, flag], [canvas, flag])

                def show_more_buttons():
                    """Relabel the step-1.2 controls once segmentation has finished."""
                    return (gr.Button("1.2 Load original masks"), gr.Button("1.2 Load edited masks"),
                            gr.Checkbox(label="Show Segmentation"))

                block_flag.change(show_more_buttons, [], [text_button, load_edit_button, show_segment])

                # Deliberate aliasing: from here on the edit controls read and write
                # the same State that the load buttons fill.
                mask_np_list_updated = mask_np_list

            with gr.Column():
                gr.Markdown("\n\nEdit Mask (Optional)\n\n")
                slider = gr.Slider(0, 20, step=1, interactive=True)
                label = gr.Textbox()
                slider.release(slider_release,
                               inputs=[slider, image_loaded, mask_np_list_updated, mask_label_list],
                               outputs=[canvas, label])
                add_button = gr.Button("Add")
                add_button.click(edit_mask_add, [canvas, image_loaded, slider, mask_np_list_updated],
                                 [mask_np_list_updated, canvas])
                save_button2 = gr.Button("Set and Save as edited masks")
                save_button2.click(save_as_edit_mask, [mask_np_list_updated, mask_label_list], [])
                save_button = gr.Button("Set and Save as original masks")
                save_button.click(save_as_orig_mask, [mask_np_list_updated, mask_label_list], [])
                back_button = gr.Button("Back to current seg")
                back_button.click(load_mask_ui, [], [mask_np_list_updated, mask_label_list])
                add_mask_button = gr.Button("Add new empty mask")
                add_mask_button.click(add_mask, [mask_np_list_updated, mask_label_list],
                                      [mask_np_list_updated, mask_label_list])

    with gr.Tab(label="2 Optimization"):
        with gr.Row():
            with gr.Column():
                txt_box = gr.Textbox("Click to start optimization...", interactive=False)
                opt_flag = gr.State(0)
                gr.Markdown("\n\nOptimization settings (SD)\n\n")
                num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive=True)
                # The editing tab reads the token count from this component.
                num_tokens_global = num_tokens
                embedding_learning_rate = gr.Textbox(value="0.0001", label="Embedding optimization: Learning rate",
                                                     interactive=True)
                max_emb_train_steps = gr.Number(value="200", label="embedding optimization: Training steps",
                                                interactive=True)
                diffusion_model_learning_rate = gr.Textbox(value="0.00005", label="UNet Optimization: Learning rate",
                                                           interactive=True)
                max_diffusion_train_steps = gr.Number(value="200", label="UNet Optimization: Training steps",
                                                      interactive=True)
                train_batch_size = gr.Number(value="5", label="Batch size", interactive=True)
                gradient_accumulation_steps = gr.Number(value="5", label="Gradient accumulation", interactive=True)
                opt_button = gr.Button("Run optimization")

                def run_optimization_wrapper(opt_flag, num_tokens, embedding_learning_rate, max_emb_train_steps,
                                             diffusion_model_learning_rate, max_diffusion_train_steps,
                                             train_batch_size, gradient_accumulation_steps):
                    """Run run_main() with the UI settings, then bump ``opt_flag`` to signal completion."""
                    run_main(
                        num_tokens=int(num_tokens),
                        embedding_learning_rate=float(embedding_learning_rate),
                        max_emb_train_steps=int(max_emb_train_steps),
                        diffusion_model_learning_rate=float(diffusion_model_learning_rate),
                        max_diffusion_train_steps=int(max_diffusion_train_steps),
                        train_batch_size=int(train_batch_size),
                        gradient_accumulation_steps=int(gradient_accumulation_steps),
                    )
                    return opt_flag + 1

                opt_button.click(run_optimization_wrapper,
                                 inputs=[opt_flag, num_tokens, embedding_learning_rate, max_emb_train_steps,
                                         diffusion_model_learning_rate, max_diffusion_train_steps,
                                         train_batch_size, gradient_accumulation_steps],
                                 outputs=[opt_flag])

                def change_text(txt_box):
                    return gr.Textbox("Optimization Finished!", interactive=False)

                def change_text2(txt_box):
                    return gr.Textbox("Start optimization, check logs for progress...", interactive=False)

                opt_button.click(change_text2, txt_box, txt_box)
                opt_flag.change(change_text, txt_box, txt_box)

    with gr.Tab(label="3 Editing"):
        with gr.Tab(label="3.1 Text-based editing"):
            with gr.Row():
                with gr.Column():
                    canvas_text_edit = gr.Image(value=None, type="pil", label="Editing results", show_label=True)
                    # canvas_text_edit = gr.Gallery(label = "Edited results")

                with gr.Column():
                    gr.Markdown("\n\nEditing setting (SD)\n\n")
                    tgt_prompt = gr.Textbox(value="White bag", label="Editing: Text prompt", interactive=True)
                    tgt_index = gr.Number(value="0", label="Editing: Object index", interactive=True)
                    guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive=True)
                    num_sampling_steps = gr.Number(value="50", label="Editing: Sampling steps", interactive=True)
                    edge_thickness = gr.Number(value="10", label="Editing: Edge thickness", interactive=True)
                    strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive=True)
                    edit_button = gr.Button("Run Editing")

                    def run_edit_text_wrapper(num_tokens, guidance_scale, num_sampling_steps, strength,
                                              edge_thickness, tgt_prompt, tgt_index):
                        """Run text-based editing with the trained model and return run_main()'s result.

                        ``num_tokens`` is the live value of the optimization tab's
                        token-count field. The previous code read the component's
                        initial ``.value``, so user edits were silently ignored.
                        """
                        return run_main(
                            load_trained=True,
                            text=True,
                            num_tokens=int(num_tokens),
                            guidance_scale=float(guidance_scale),
                            num_sampling_steps=int(num_sampling_steps),
                            strength=float(strength),
                            edge_thickness=int(edge_thickness),
                            num_imgs=1,
                            tgt_prompt=tgt_prompt,
                            tgt_index=int(tgt_index),
                        )

                    edit_button.click(run_edit_text_wrapper,
                                      inputs=[num_tokens_global, guidance_scale, num_sampling_steps, strength,
                                              edge_thickness, tgt_prompt, tgt_index],
                                      outputs=[canvas_text_edit])

                    def load_pil_img():
                        """Open the first text-editing result written to the scratch folder."""
                        return Image.open("example_tmp/text/out_text_0.png")

                    load_button = gr.Button("Load results")
                    load_button.click(load_pil_img, inputs=[], outputs=[canvas_text_edit])

demo.queue().launch(share=True, debug=True)