import os
import torch
import numpy as np
import argparse
from peft import LoraConfig
from pipeline_dedit_sdxl import DEditSDXLPipeline
from pipeline_dedit_sd import DEditSDPipeline
from utils import load_image, load_mask, load_mask_edit
from utils_mask import process_mask_move_torch, process_mask_remove_torch, mask_union_torch, mask_substract_torch, create_outer_edge_mask_torch
from utils_mask import check_mask_overlap_torch, check_cover_all_torch, visualize_mask_list, get_mask_difference_torch, save_mask_list_to_npys


def _load_input(folder, resolution):
    """Load the segmentation masks and the input image stored in *folder*.

    The image is read from ``img_<resolution>.png`` when that file exists and
    from ``img_<resolution>.jpg`` otherwise.

    Returns:
        ``(mask_list, mask_label_list, image)``.

    Raises:
        ValueError: if ``mask_list[0].shape[0]`` differs from ``resolution``.
    """
    mask_list, mask_label_list = load_mask(folder)
    if mask_list[0].shape[0] != resolution:
        raise ValueError("Segmentation should be done on size {}".format(resolution))
    # Explicit existence check instead of a bare ``except:``, so a corrupt PNG
    # (or Ctrl-C) is reported instead of silently falling back to the JPG.
    png_path = os.path.join(folder, "img_{}.png".format(resolution))
    if os.path.exists(png_path):
        image = load_image(png_path, size=resolution)
    else:
        image = load_image(os.path.join(folder, "img_{}.jpg".format(resolution)), size=resolution)
    return mask_list, mask_label_list, image


def _load_edited_masks(folder, mask_list):
    """Load hand-edited masks from *folder* with ``load_mask_edit``.

    Returns:
        ``(mask_list_edited, mask_diff)`` where ``mask_diff`` is
        ``get_mask_difference_torch(mask_list_edited, mask_list)``.
    """
    mask_list_edited, _ = load_mask_edit(folder)
    return mask_list_edited, get_mask_difference_torch(mask_list_edited, mask_list)


def _make_dir(parent, name):
    """Create (if needed) and return ``os.path.join(parent, name)``."""
    path = os.path.join(parent, name)
    os.makedirs(path, exist_ok=True)
    return path


def run_main(
    name="example_tmp",
    name_2=None,
    dpm="sd",
    resolution=512,
    seed=42,
    embedding_learning_rate=1e-4,
    max_emb_train_steps=200,
    diffusion_model_learning_rate=5e-5,
    max_diffusion_train_steps=200,
    train_batch_size=1,
    gradient_accumulation_steps=1,
    num_tokens=1,
    load_trained=False,
    num_sampling_steps=50,
    guidance_scale=3,
    strength=0.8,
    train_full_lora=False,
    lora_rank=4,
    lora_alpha=4,
    prompt_auxin_list=None,
    prompt_auxin_idx_list=None,
    load_edited_mask=False,
    load_edited_processed_mask=False,
    edge_thickness=20,
    num_imgs=1,
    active_mask_list=None,
    tgt_index=None,
    recon=False,
    recon_an_item=False,
    recon_prompt=None,
    text=False,
    tgt_prompt=None,
    image=False,
    src_index=None,
    tgt_name=None,
    move_resize=False,
    tgt_indices_list=None,
    delta_x_list=None,
    delta_y_list=None,
    priority_list=None,
    force_mask_remain=None,
    resize_list=None,
    remove=False,
    load_edited_removemask=False,
):
    """Train (or load) a D-Edit pipeline for an input folder, then run the requested edits.

    Inputs are read from ``./<name>`` (and ``./<name_2>`` when ``image`` is
    set). Checkpoints are written to / read from ``./<name>`` or
    ``./<name>_<name_2>``. Each enabled editing mode writes into its own
    subfolder of that directory (``recon``, ``text``, ``remove``, ``image``,
    ``move_resize``).

    Args:
        name: Input folder holding the masks and ``img_<resolution>.{png,jpg}``.
        name_2: Second input folder; required when ``image`` is True.
        dpm: ``"sd"`` selects ``DEditSDPipeline``, ``"sdxl"`` selects
            ``DEditSDXLPipeline``.
        load_trained: If True, load ``unet.pt`` / ``text_encoder1.pt`` (and
            ``text_encoder2.pt`` for sdxl) instead of training and saving them.
        prompt_auxin_list / prompt_auxin_idx_list: Prompt templates whose
            ``*`` is replaced by the existing string at the paired index.
        load_edited_mask: Load hand-edited masks via ``load_mask_edit`` and
            add the edited region to the active (re-generated) area.
        load_edited_processed_mask: For remove / move_resize, load processed
            masks from the mode's output folder instead of computing them.
        recon / text / remove / image / move_resize: Editing modes to run.
        tgt_index: Index of the target item for recon_an_item, text, remove
            and image modes.
        load_edited_removemask: Accepted for interface compatibility; not
            read by this function.

    Raises:
        ValueError: on resolution mismatch, missing ``name_2`` in image mode,
            or ``tgt_name`` that is neither ``name`` nor ``name_2``.
        NotImplementedError: for an unsupported ``dpm``.
    """
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    base_input_folder = "."
    base_output_folder = "."

    input_folder = os.path.join(base_input_folder, name)
    mask_list, mask_label_list, image_gt = _load_input(input_folder, resolution)

    if image:
        if name_2 is None:
            raise ValueError("name_2 is required when image=True")
        input_folder_2 = os.path.join(base_input_folder, name_2)
        mask_list_2, mask_label_list_2, image_gt_2 = _load_input(input_folder_2, resolution)
        output_dir = os.path.join(base_output_folder, name + "_" + name_2)
    else:
        output_dir = os.path.join(base_output_folder, name)
    os.makedirs(output_dir, exist_ok=True)

    pipelines = {"sd": DEditSDPipeline, "sdxl": DEditSDXLPipeline}
    if dpm not in pipelines:
        raise NotImplementedError("dpm must be 'sd' or 'sdxl', got {!r}".format(dpm))
    pipeline_cls = pipelines[dpm]
    if image:
        pipe = pipeline_cls(mask_list, mask_label_list, mask_list_2, mask_label_list_2,
                            resolution=resolution, num_tokens=num_tokens)
    else:
        pipe = pipeline_cls(mask_list, mask_label_list, resolution=resolution, num_tokens=num_tokens)

    # Same list object as pipe.set_string_list: the auxiliary-prompt edits
    # below are intentionally visible through the pipeline as well (the
    # remove / move_resize inference calls do not pass set_string_list).
    set_string_list = pipe.set_string_list
    if prompt_auxin_list is not None:
        for auxin_idx, auxin_prompt in zip(prompt_auxin_idx_list, prompt_auxin_list):
            set_string_list[auxin_idx] = auxin_prompt.replace("*", set_string_list[auxin_idx])
    print(set_string_list)
    if image:
        set_string_list_2 = pipe.set_string_list_2
        print(set_string_list_2)

    unet_save_path = os.path.join(output_dir, "unet.pt")
    text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
    text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
    if load_trained:
        # NOTE: torch.load unpickles; only point this at checkpoints you produced.
        unet_state_dict = torch.load(unet_save_path)
        text_encoder1_state_dict = torch.load(text_encoder1_save_path)
        if dpm == "sdxl":
            text_encoder2_state_dict = torch.load(text_encoder2_save_path)
        if any("lora" in key for key in unet_state_dict.keys()):
            # The checkpoint holds LoRA weights, so the adapter must exist
            # on the UNet before its state dict can be loaded.
            unet_lora_config = LoraConfig(
                r=lora_rank,
                lora_alpha=lora_alpha,
                init_lora_weights="gaussian",
                target_modules=["to_k", "to_q", "to_v", "to_out.0"],
            )
            pipe.unet.add_adapter(unet_lora_config)
        pipe.unet.load_state_dict(unet_state_dict)
        pipe.text_encoder.load_state_dict(text_encoder1_state_dict)
        if dpm == "sdxl":
            pipe.text_encoder_2.load_state_dict(text_encoder2_state_dict)
    else:
        emb_kwargs = dict(
            gradient_accumulation_steps=gradient_accumulation_steps,
            embedding_learning_rate=embedding_learning_rate,
            max_emb_train_steps=max_emb_train_steps,
            train_batch_size=train_batch_size,
        )
        model_kwargs = dict(
            gradient_accumulation_steps=gradient_accumulation_steps,
            max_diffusion_train_steps=max_diffusion_train_steps,
            diffusion_model_learning_rate=diffusion_model_learning_rate,
            train_batch_size=train_batch_size,
            train_full_lora=train_full_lora,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
        )
        pipe.mask_list = [m.cuda() for m in pipe.mask_list]
        if image:
            pipe.mask_list_2 = [m.cuda() for m in pipe.mask_list_2]
            pipe.train_emb_2imgs(image_gt, image_gt_2, set_string_list, set_string_list_2, **emb_kwargs)
            pipe.train_model_2imgs(image_gt, image_gt_2, set_string_list, set_string_list_2, **model_kwargs)
        else:
            pipe.train_emb(image_gt, set_string_list, **emb_kwargs)
            pipe.train_model(image_gt, set_string_list, **model_kwargs)
        torch.save(pipe.unet.state_dict(), unet_save_path)
        torch.save(pipe.text_encoder.state_dict(), text_encoder1_save_path)
        if dpm == "sdxl":
            torch.save(pipe.text_encoder_2.state_dict(), text_encoder2_save_path)

    # Each mode below works on its own locals and its own subfolder of
    # output_dir, so enabling several modes in one run does not nest output
    # folders or leak edited masks / prompts from one mode into the next.
    if recon:
        recon_dir = _make_dir(output_dir, "recon")
        recon_mask_list = mask_list
        recon_string_list = set_string_list
        if recon_an_item:
            # Regenerate only the tgt_index item, over a single all-ones mask.
            recon_mask_list = [torch.ones_like(mask_list[0])]
            recon_string_list = [recon_prompt.replace("*", set_string_list[tgt_index])]
            print(recon_string_list)
        pipe.inference_with_mask(
            os.path.join(recon_dir, "out_recon.png"),
            guidance_scale=guidance_scale,
            num_sampling_steps=num_sampling_steps,
            seed=seed,
            num_imgs=num_imgs,
            set_string_list=recon_string_list,
            mask_list=recon_mask_list,
        )

    if text:
        print("*** Text-guided editing ")
        text_dir = _make_dir(output_dir, "text")
        save_path = os.path.join(text_dir, "out_text.png")
        # Copy so the pipeline's own prompt list is not rewritten for later modes.
        text_string_list = list(set_string_list)
        text_string_list[tgt_index] = tgt_prompt
        text_mask_list = mask_list
        mask_active = torch.zeros_like(mask_list[0])
        mask_active = mask_union_torch(mask_active, mask_list[tgt_index])
        if active_mask_list is not None:
            for midx in active_mask_list:
                mask_active = mask_union_torch(mask_active, mask_list[midx])
        if load_edited_mask:
            text_mask_list, mask_diff = _load_edited_masks(input_folder, mask_list)
            mask_active = mask_union_torch(mask_active, mask_diff)
            save_path = os.path.join(text_dir, "out_textEdited.png")
        # hard = everything outside the active area minus a soft band around it.
        mask_hard = mask_substract_torch(torch.ones_like(text_mask_list[0]), mask_active)
        mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness=edge_thickness)
        mask_hard = mask_substract_torch(mask_hard, mask_soft)
        pipe.inference_with_mask(
            save_path,
            orig_image=image_gt,
            set_string_list=text_string_list,
            guidance_scale=guidance_scale,
            strength=strength,
            num_imgs=num_imgs,
            mask_hard=mask_hard,
            mask_soft=mask_soft,
            mask_list=text_mask_list,
            seed=seed,
            num_sampling_steps=num_sampling_steps,
        )

    if remove:
        remove_dir = _make_dir(output_dir, "remove")
        save_path = os.path.join(remove_dir, "out_remove.png")
        remove_mask_list = mask_list
        mask_active = torch.zeros_like(mask_list[0])
        if load_edited_mask:
            remove_mask_list, mask_diff = _load_edited_masks(input_folder, mask_list)
            mask_active = mask_union_torch(mask_active, mask_diff)
        if load_edited_processed_mask:
            # Masks manually edited/drawn after removing one index.
            mask_list_processed, _ = load_mask_edit(remove_dir)
            mask_remain = get_mask_difference_torch(mask_list_processed, remove_mask_list)
        else:
            # Masks generated after removing one index (process_mask_remove_torch).
            mask_list_processed, mask_remain = process_mask_remove_torch(remove_mask_list, tgt_index)
            save_mask_list_to_npys(remove_dir, mask_list_processed, mask_label_list, name="mask")
            visualize_mask_list(mask_list_processed, os.path.join(remove_dir, "seg_removed.png"))
        check_cover_all_torch(*mask_list_processed)
        mask_active = mask_union_torch(mask_active, mask_remain)
        if active_mask_list is not None:
            for midx in active_mask_list:
                mask_active = mask_union_torch(mask_active, remove_mask_list[midx])
        mask_hard = 1 - mask_active
        mask_soft = create_outer_edge_mask_torch(mask_remain, edge_thickness=edge_thickness)
        mask_hard = mask_substract_torch(mask_hard, mask_soft)
        pipe.inference_with_mask(
            save_path,
            orig_image=image_gt,
            guidance_scale=guidance_scale,
            strength=strength,
            num_imgs=num_imgs,
            mask_hard=mask_hard,
            mask_soft=mask_soft,
            mask_list=mask_list_processed,
            seed=seed,
            num_sampling_steps=num_sampling_steps,
        )

    if image:
        image_dir = _make_dir(output_dir, "image")
        save_path = os.path.join(image_dir, "out_image.png")
        if all(v is not None for v in (tgt_name, src_index, tgt_index)):
            mask_diff = None
            if tgt_name == name:
                string_list_tgt = list(set_string_list)
                string_list_src = set_string_list_2
                image_tgt = image_gt
                mask_list_tgt = mask_list
                if load_edited_mask:
                    mask_list_tgt, mask_diff = _load_edited_masks(input_folder, mask_list)
                    save_path = os.path.join(image_dir, "out_imageEdited.png")
            elif tgt_name == name_2:
                string_list_tgt = list(set_string_list_2)
                string_list_src = set_string_list
                image_tgt = image_gt_2
                mask_list_tgt = mask_list_2
                if load_edited_mask:
                    mask_list_tgt, mask_diff = _load_edited_masks(input_folder_2, mask_list_2)
                    save_path = os.path.join(image_dir, "out_imageEdited.png")
            else:
                raise ValueError("tgt_name should be either name or name_2")
            # Transfer the learned source item's string into the target slot.
            string_list_tgt[tgt_index] = string_list_src[src_index]
            mask_active = mask_list_tgt[tgt_index]
            if mask_diff is not None:
                # Keep the hand-edited region active (previously computed then discarded).
                mask_active = mask_union_torch(mask_active, mask_diff)
            mask_frozen = 1 - mask_active.float()
            mask_soft = create_outer_edge_mask_torch(mask_active.cpu(), edge_thickness=edge_thickness)
            mask_hard = mask_substract_torch(mask_frozen.cpu(), mask_soft.cpu())
            mask_list_tgt = [m.cuda() for m in mask_list_tgt]
            pipe.inference_with_mask(
                save_path,
                set_string_list=string_list_tgt,
                mask_list=mask_list_tgt,
                guidance_scale=guidance_scale,
                num_sampling_steps=num_sampling_steps,
                mask_hard=mask_hard.cuda(),
                mask_soft=mask_soft.cuda(),
                num_imgs=num_imgs,
                orig_image=image_tgt,
                strength=strength,
                seed=seed,
            )
        else:
            print("*** Skipping image-guided editing: tgt_name, src_index and tgt_index are all required")

    if move_resize:
        move_dir = _make_dir(output_dir, "move_resize")
        save_path = os.path.join(move_dir, "out_moveresize.png")
        move_mask_list = mask_list
        mask_active = torch.zeros_like(mask_list[0])
        if load_edited_mask:
            move_mask_list, mask_diff = _load_edited_masks(input_folder, mask_list)
            mask_active = mask_union_torch(mask_active, mask_diff)
        if load_edited_processed_mask:
            mask_list_processed, _ = load_mask_edit(move_dir)
            mask_remain = get_mask_difference_torch(mask_list_processed, move_mask_list)
        else:
            mask_list_processed, mask_remain = process_mask_move_torch(
                move_mask_list,
                tgt_indices_list,
                delta_x_list,
                delta_y_list,
                priority_list,
                force_mask_remain=force_mask_remain,
                resize_list=resize_list,
            )
            save_mask_list_to_npys(move_dir, mask_list_processed, mask_label_list, name="mask")
            visualize_mask_list(mask_list_processed, os.path.join(move_dir, "seg_move_resize.png"))
        active_idxs = tgt_indices_list
        mask_active = mask_union_torch(
            mask_active, *[m for midx, m in enumerate(mask_list_processed) if midx in active_idxs]
        )
        mask_active = mask_union_torch(mask_remain, mask_active)
        if active_mask_list is not None:
            for midx in active_mask_list:
                mask_active = mask_union_torch(mask_active, mask_list_processed[midx])
        mask_frozen = 1 - mask_active.float()
        mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness=edge_thickness)
        mask_hard = mask_substract_torch(mask_frozen, mask_soft)
        check_mask_overlap_torch(mask_hard, mask_soft)
        pipe.inference_with_mask(
            save_path,
            strength=strength,
            orig_image=image_gt,
            guidance_scale=guidance_scale,
            num_sampling_steps=num_sampling_steps,
            num_imgs=num_imgs,
            mask_hard=mask_hard,
            mask_soft=mask_soft,
            mask_list=mask_list_processed,
            seed=seed,
        )