import os
import spaces
import torch
import numpy as np
import argparse
from peft import LoraConfig
from pipeline_dedit_sdxl import DEditSDXLPipeline
from pipeline_dedit_sd import DEditSDPipeline
from utils import load_image, load_mask, load_mask_edit
from utils_mask import process_mask_move_torch, process_mask_remove_torch, mask_union_torch, mask_substract_torch, create_outer_edge_mask_torch
from utils_mask import check_mask_overlap_torch, check_cover_all_torch, visualize_mask_list, get_mask_difference_torch, save_mask_list_to_npys
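
# Entry point for the DEdit Space: learns per-segment token embeddings (and optionally
# fine-tunes the diffusion model with LoRA) from the segmented input image(s), then runs
# the requested editing mode: recon, text, remove, image, or move_resize.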
@spaces.GPU(duration=45)
def run_main(
name="example_tmp",
name_2=None,
mask_np_list=None,
mask_label_list=None,
image_gt=None,
dpm="sd",
resolution=512,
seed=42,
embedding_learning_rate=1e-4,
max_emb_train_steps=200,
diffusion_model_learning_rate=5e-5,
max_diffusion_train_steps=200,
train_batch_size=1,
gradient_accumulation_steps=1,
num_tokens=1,
    load_trained=False,
    num_sampling_steps=50,
    guidance_scale=3,
    strength=0.8,
    train_full_lora=False,
    lora_rank=4,
    lora_alpha=4,
    prompt_auxin_list=None,
    prompt_auxin_idx_list=None,
load_edited_mask=False,
load_edited_processed_mask=False,
edge_thickness=20,
    num_imgs=1,
    active_mask_list=None,
    tgt_index=None,
    recon=False,
recon_an_item=False,
recon_prompt=None,
text=False,
tgt_prompt=None,
    image=False,
src_index=None,
tgt_name=None,
    move_resize=False,
tgt_indices_list=None,
delta_x_list=None,
delta_y_list=None,
priority_list=None,
force_mask_remain=None,
resize_list=None,
remove=False,
load_edited_removemask=False
):
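    # Fix random seeds so training and sampling are reproducible.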
torch.cuda.manual_seed_all(seed)
torch.manual_seed(seed)
base_input_folder = "."
base_output_folder = "."
input_folder = os.path.join(base_input_folder, name)
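    # Convert the incoming numpy segmentation masks to torch tensors.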
mask_list = []
for mask_np in mask_np_list:
mask = torch.from_numpy(mask_np.astype(np.uint8))
mask_list.append(mask)
#mask_list, mask_label_list = load_mask(input_folder)
assert mask_list[0].shape[0] == resolution, "Segmentation should be done on size {}".format(resolution)
#try:
# image_gt = load_image(os.path.join(input_folder, "img_{}.png".format(resolution) ), size = resolution)
#except:
# image_gt = load_image(os.path.join(input_folder, "img_{}.jpg".format(resolution) ), size = resolution)
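    # Image-based editing uses a second reference image: load its segmentation and write
    # results to a combined output folder; otherwise output goes under the single name.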
if image:
input_folder_2 = os.path.join(base_input_folder, name_2)
mask_list_2, mask_label_list_2 = load_mask(input_folder_2)
assert mask_list_2[0].shape[0] == resolution, "Segmentation should be done on size {}".format(resolution)
        try:
            image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.png".format(resolution)), size = resolution)
        except OSError:
            image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.jpg".format(resolution)), size = resolution)
output_dir = os.path.join(base_output_folder, name + "_" + name_2)
os.makedirs(output_dir, exist_ok = True)
else:
output_dir = os.path.join(base_output_folder, name)
os.makedirs(output_dir, exist_ok = True)
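    # Instantiate the DEdit pipeline for the chosen base model (Stable Diffusion or SDXL).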
if dpm == "sd":
if image:
pipe = DEditSDPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution = resolution, num_tokens = num_tokens)
else:
pipe = DEditSDPipeline(mask_list, mask_label_list, resolution = resolution, num_tokens = num_tokens)
elif dpm == "sdxl":
if image:
pipe = DEditSDXLPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution = resolution, num_tokens = num_tokens)
else:
pipe = DEditSDXLPipeline(mask_list, mask_label_list, resolution = resolution, num_tokens = num_tokens)
else:
raise NotImplementedError
set_string_list = pipe.set_string_list
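    # Optionally wrap selected concept tokens in auxiliary prompt templates;
    # "*" marks where the learned token string is inserted.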
if prompt_auxin_list is not None:
for auxin_idx, auxin_prompt in zip(prompt_auxin_idx_list, prompt_auxin_list):
            set_string_list[auxin_idx] = auxin_prompt.replace("*", set_string_list[auxin_idx])
print(set_string_list)
if image:
set_string_list_2 = pipe.set_string_list_2
print(set_string_list_2)
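    # Either reload previously fine-tuned weights, or train the embeddings and the model now.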
if load_trained:
unet_save_path = os.path.join(output_dir, "unet.pt")
unet_state_dict = torch.load(unet_save_path)
text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
text_encoder1_state_dict = torch.load(text_encoder1_save_path)
if dpm == "sdxl":
text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
text_encoder2_state_dict = torch.load(text_encoder2_save_path)
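        # Checkpoints trained with LoRA need a matching adapter attached before loading the state dict.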
        if any("lora" in key for key in unet_state_dict.keys()):
unet_lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
pipe.unet.add_adapter(unet_lora_config)
pipe.unet.load_state_dict(unet_state_dict)
pipe.text_encoder.load_state_dict(text_encoder1_state_dict)
if dpm == "sdxl":
pipe.text_encoder_2.load_state_dict(text_encoder2_state_dict)
else:
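        # No saved weights: first learn the per-segment token embeddings, then fine-tune
        # the diffusion model (LoRA, or full LoRA when train_full_lora is set).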
if image:
pipe.mask_list = [m.cuda() for m in pipe.mask_list]
pipe.mask_list_2 = [m.cuda() for m in pipe.mask_list_2]
pipe.train_emb_2imgs(
image_gt,
image_gt_2,
set_string_list,
set_string_list_2,
gradient_accumulation_steps = gradient_accumulation_steps,
embedding_learning_rate = embedding_learning_rate,
max_emb_train_steps = max_emb_train_steps,
train_batch_size = train_batch_size,
)
pipe.train_model_2imgs(
image_gt,
image_gt_2,
set_string_list,
set_string_list_2,
gradient_accumulation_steps = gradient_accumulation_steps,
max_diffusion_train_steps = max_diffusion_train_steps,
                diffusion_model_learning_rate = diffusion_model_learning_rate,
                train_batch_size = train_batch_size,
train_full_lora = train_full_lora,
lora_rank = lora_rank,
lora_alpha = lora_alpha
)
else:
pipe.mask_list = [m.cuda() for m in pipe.mask_list]
pipe.train_emb(
image_gt,
set_string_list,
gradient_accumulation_steps = gradient_accumulation_steps,
embedding_learning_rate = embedding_learning_rate,
max_emb_train_steps = max_emb_train_steps,
train_batch_size = train_batch_size,
)
pipe.train_model(
image_gt,
set_string_list,
gradient_accumulation_steps = gradient_accumulation_steps,
max_diffusion_train_steps = max_diffusion_train_steps,
                diffusion_model_learning_rate = diffusion_model_learning_rate,
train_batch_size = train_batch_size,
train_full_lora = train_full_lora,
lora_rank = lora_rank,
lora_alpha = lora_alpha
)
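        # Save the fine-tuned weights so later runs can reuse them via load_trained=True.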
        unet_save_path = os.path.join(output_dir, "unet.pt")
        torch.save(pipe.unet.state_dict(), unet_save_path)
        text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
        torch.save(pipe.text_encoder.state_dict(), text_encoder1_save_path)
        if dpm == "sdxl":
            text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
            torch.save(pipe.text_encoder_2.state_dict(), text_encoder2_save_path)
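    # Reconstruction: re-synthesize the whole image (or a single item when recon_an_item
    # is set) from the learned concepts.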
if recon:
output_dir = os.path.join(output_dir, "recon")
os.makedirs(output_dir, exist_ok = True)
if recon_an_item:
mask_list = [torch.from_numpy(np.ones_like(mask_list[0].numpy()))]
tgt_string = set_string_list[tgt_index]
tgt_string = recon_prompt.replace("*", tgt_string)
set_string_list = [tgt_string]
print(set_string_list)
save_path = os.path.join(output_dir, "out_recon.png")
x_np = pipe.inference_with_mask(
save_path,
guidance_scale = guidance_scale,
num_sampling_steps = num_sampling_steps,
seed = seed,
num_imgs = num_imgs,
set_string_list = set_string_list,
mask_list = mask_list
)
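    # Text-guided editing: swap the target segment's prompt for tgt_prompt and regenerate that region.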
if text:
print("*** Text-guided editing ")
output_dir = os.path.join(output_dir, "text")
os.makedirs(output_dir, exist_ok = True)
save_path = os.path.join(output_dir, "out_text.png")
set_string_list[tgt_index] = tgt_prompt
mask_active = torch.zeros_like(mask_list[0])
mask_active = mask_union_torch(mask_active, mask_list[tgt_index])
if active_mask_list is not None:
for midx in active_mask_list:
mask_active = mask_union_torch(mask_active, mask_list[midx])
if load_edited_mask:
mask_list_edited, mask_label_list_edited = load_mask_edit(input_folder)
mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
mask_active = mask_union_torch(mask_active, mask_diff)
mask_list = mask_list_edited
save_path = os.path.join(output_dir, "out_textEdited.png")
mask_hard = mask_substract_torch(torch.ones_like(mask_list[0]), mask_active)
mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness = edge_thickness)
mask_hard = mask_substract_torch(mask_hard, mask_soft)
pipe.inference_with_mask(
save_path,
orig_image = image_gt,
set_string_list = set_string_list,
guidance_scale = guidance_scale,
strength = strength,
num_imgs = num_imgs,
            mask_hard = mask_hard,
mask_soft = mask_soft,
mask_list = mask_list,
seed = seed,
num_sampling_steps = num_sampling_steps
)
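    # Removal: drop the target segment and reassign its area to the remaining masks
    # (nearest-neighbour by default) before regenerating.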
if remove:
output_dir = os.path.join(output_dir, "remove")
save_path = os.path.join(output_dir, "out_remove.png")
os.makedirs(output_dir, exist_ok = True)
mask_active = torch.zeros_like(mask_list[0])
if load_edited_mask:
mask_list_edited, _ = load_mask_edit(input_folder)
mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
mask_active = mask_union_torch(mask_active, mask_diff)
mask_list = mask_list_edited
if load_edited_processed_mask:
# manually edit or draw masks after removing one index, then load
mask_list_processed, _ = load_mask_edit(output_dir)
mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
else:
# generate masks after removing one index, using nearest neighbor algorithm
mask_list_processed, mask_remain = process_mask_remove_torch(mask_list, tgt_index)
save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name = "mask")
visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_removed.png"))
check_cover_all_torch(*mask_list_processed)
mask_active = mask_union_torch(mask_active, mask_remain)
if active_mask_list is not None:
for midx in active_mask_list:
mask_active = mask_union_torch(mask_active, mask_list[midx])
mask_hard = 1 - mask_active
mask_soft = create_outer_edge_mask_torch(mask_remain, edge_thickness = edge_thickness)
mask_hard = mask_substract_torch(mask_hard, mask_soft)
pipe.inference_with_mask(
save_path,
orig_image = image_gt,
guidance_scale = guidance_scale,
strength = strength,
num_imgs = num_imgs,
            mask_hard = mask_hard,
mask_soft = mask_soft,
mask_list = mask_list_processed,
seed = seed,
num_sampling_steps = num_sampling_steps
)
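    # Image-based editing: replace the target segment in one image with a concept taken
    # from the other image.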
if image:
output_dir = os.path.join(output_dir, "image")
save_path = os.path.join(output_dir, "out_image.png")
os.makedirs(output_dir, exist_ok = True)
mask_active = torch.zeros_like(mask_list[0])
if None not in (tgt_name, src_index, tgt_index):
if tgt_name == name:
set_string_list_tgt = set_string_list
set_string_list_src = set_string_list_2
image_tgt = image_gt
if load_edited_mask:
mask_list_edited, _ = load_mask_edit(input_folder)
mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
mask_active = mask_union_torch(mask_active, mask_diff)
mask_list = mask_list_edited
save_path = os.path.join(output_dir, "out_imageEdited.png")
mask_list_tgt = mask_list
elif tgt_name == name_2:
set_string_list_tgt = set_string_list_2
set_string_list_src = set_string_list
image_tgt = image_gt_2
if load_edited_mask:
mask_list_2_edited, _ = load_mask_edit(input_folder_2)
mask_diff = get_mask_difference_torch(mask_list_2_edited, mask_list_2)
mask_active = mask_union_torch(mask_active, mask_diff)
mask_list_2 = mask_list_2_edited
save_path = os.path.join(output_dir, "out_imageEdited.png")
mask_list_tgt = mask_list_2
else:
                raise ValueError("tgt_name should be either name or name_2")
set_string_list_tgt[tgt_index] = set_string_list_src[src_index]
mask_active = mask_list_tgt[tgt_index]
mask_frozen = (1-mask_active.float()).to(mask_active.device)
mask_soft = create_outer_edge_mask_torch(mask_active.cpu(), edge_thickness = edge_thickness)
mask_hard = mask_substract_torch(mask_frozen.cpu(), mask_soft.cpu())
mask_list_tgt = [m.cuda() for m in mask_list_tgt]
pipe.inference_with_mask(
save_path,
set_string_list = set_string_list_tgt,
mask_list = mask_list_tgt,
guidance_scale = guidance_scale,
num_sampling_steps = num_sampling_steps,
mask_hard = mask_hard.cuda(),
mask_soft = mask_soft.cuda(),
num_imgs = num_imgs,
orig_image = image_tgt,
strength = strength,
)
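    # Move/resize: translate and/or rescale the selected segments, then regenerate the
    # background they uncover.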
if move_resize:
output_dir = os.path.join(output_dir, "move_resize")
os.makedirs(output_dir, exist_ok = True)
save_path = os.path.join(output_dir, "out_moveresize.png")
mask_active = torch.zeros_like(mask_list[0])
if load_edited_mask:
mask_list_edited, _ = load_mask_edit(input_folder)
mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
mask_active = mask_union_torch(mask_active, mask_diff)
mask_list = mask_list_edited
# save_path = os.path.join(output_dir, "out_moveresizeEdited.png")
if load_edited_processed_mask:
mask_list_processed, _ = load_mask_edit(output_dir)
mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
else:
mask_list_processed, mask_remain = process_mask_move_torch(
mask_list,
tgt_indices_list,
delta_x_list,
                delta_y_list,
                priority_list,
                force_mask_remain = force_mask_remain,
                resize_list = resize_list
)
save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name = "mask")
visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_move_resize.png"))
active_idxs = tgt_indices_list
mask_active = mask_union_torch(mask_active, *[m for midx, m in enumerate(mask_list_processed) if midx in active_idxs])
mask_active = mask_union_torch(mask_remain, mask_active)
if active_mask_list is not None:
for midx in active_mask_list:
mask_active = mask_union_torch(mask_active, mask_list_processed[midx])
        mask_frozen = (1 - mask_active.float())
mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness = edge_thickness)
mask_hard = mask_substract_torch(mask_frozen, mask_soft)
check_mask_overlap_torch(mask_hard, mask_soft)
pipe.inference_with_mask(
save_path,
strength = strength,
orig_image = image_gt,
guidance_scale = guidance_scale,
num_sampling_steps = num_sampling_steps,
num_imgs = num_imgs,
            mask_hard = mask_hard,
mask_soft = mask_soft,
mask_list = mask_list_processed,
seed = seed
)