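"""Conditional image synthesis with a T2I-Adapter plus MasaCtrl.

Generates a (source, target) image pair in a single batch: the source comes
from a condition image/prompt, a real image (via DDIM inversion), or a saved
start code, while MasaCtrl's mutual self-attention keeps the target image
consistent with the source.
"""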
import os

import cv2
import torch
import torch.nn.functional as F
from basicsr.utils import tensor2img
from pytorch_lightning import seed_everything
from torch import autocast
from torchvision.io import read_image

from ldm.inference_base import (diffusion_inference, get_adapters, get_base_argument_parser, get_sd_models)
from ldm.modules.extra_condition import api
from ldm.modules.extra_condition.api import (ExtraCondition, get_adapter_feature, get_cond_model)
from ldm.util import fix_cond_shapes

from masactrl.masactrl_utils import regiter_attention_editor_ldm
from masactrl.masactrl import MutualSelfAttentionControl
from masactrl.masactrl import MutualSelfAttentionControlMask
from masactrl.masactrl import MutualSelfAttentionControlMaskAuto
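
# This script only runs inference, so autograd is disabled globally.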
torch.set_grad_enabled(False)


def main():
    supported_cond = [e.name for e in ExtraCondition]
    parser = get_base_argument_parser()
    parser.add_argument(
        '--which_cond',
        type=str,
        required=True,
        choices=supported_cond,
        help='which condition modality you want to test',
    )
    parser.add_argument(
        "--cond_path_src",
        type=str,
        default=None,
        help="path to the condition image used to synthesize the source image",
    )
    parser.add_argument(
        "--prompt_src",
        type=str,
        default=None,
        help="the prompt used to synthesize the source image",
    )
    parser.add_argument(
        "--src_img_path",
        type=str,
        default=None,
        help="path to the input real source image",
    )
    parser.add_argument(
        "--start_code_path",
        type=str,
        default=None,
        help="path to the inverted start code used to synthesize the source image",
    )
    parser.add_argument(
        "--masa_step",
        type=int,
        default=4,
        help="the denoising step from which MasaCtrl starts",
    )
    parser.add_argument(
        "--masa_layer",
        type=int,
        default=10,
        help="the U-Net layer from which MasaCtrl starts",
    )

    opt = parser.parse_args()
    which_cond = opt.which_cond
    if opt.outdir is None:
        opt.outdir = f'outputs/test-{which_cond}'
    os.makedirs(opt.outdir, exist_ok=True)
    if opt.resize_short_edge is None:
        print(f"resize_short_edge is not specified, so the maximum resolution is set to {opt.max_resolution}")
    opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    if os.path.isdir(opt.cond_path):
        image_paths = [os.path.join(opt.cond_path, f) for f in os.listdir(opt.cond_path)]
    else:
        image_paths = [opt.cond_path]
    print(image_paths)
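
    # Load the SD model and sampler, plus the adapter (and, if the input type
    # is a raw image, the condition-extraction model) for the chosen modality.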
    sd_model, sampler = get_sd_models(opt)
    adapter = get_adapters(opt, getattr(ExtraCondition, which_cond))
    cond_model = None
    if opt.cond_inp_type == 'image':
        cond_model = get_cond_model(opt, getattr(ExtraCondition, which_cond))

    process_cond_module = getattr(api, f'get_cond_{which_cond}')
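
    # MasaCtrl hyperparameters: mutual self-attention is enabled from denoising
    # step STEP and U-Net layer LAYER onward.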
    STEP = opt.masa_step if opt.masa_step is not None else 4
    LAYER = opt.masa_layer if opt.masa_layer is not None else 10

    with torch.inference_mode(), \
            sd_model.ema_scope(), \
            autocast('cuda'):
        for test_idx, cond_path in enumerate(image_paths):
            seed_everything(opt.seed)
            for v_idx in range(opt.n_samples):
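                # Build the condition map(s): an optional one for the source
                # branch and one for the target branch.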
                if opt.cond_path_src:
                    cond_src = process_cond_module(opt, opt.cond_path_src, opt.cond_inp_type, cond_model)
                cond = process_cond_module(opt, cond_path, opt.cond_inp_type, cond_model)

                base_count = len(os.listdir(opt.outdir)) // 2
                cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_{which_cond}.png'), tensor2img(cond))
                if opt.cond_path_src:
                    cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_{which_cond}_src.png'), tensor2img(cond_src))

                adapter_features, append_to_context = get_adapter_feature(cond, adapter)
                if opt.cond_path_src:
                    adapter_features_src, append_to_context_src = get_adapter_feature(cond_src, adapter)
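
                # Source and target are denoised as one batch of two: row 0 carries
                # the source adapter features (zeros when no source condition is
                # given) and row 1 carries the target features.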
                if opt.cond_path_src:
                    print("using reference guidance to synthesize the image")
                    adapter_features = [torch.cat([adapter_features_src[i], adapter_features[i]]) for i in range(len(adapter_features))]
                else:
                    adapter_features = [torch.cat([torch.zeros_like(feats), feats]) for feats in adapter_features]
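
                # With classifier-free guidance the batch is doubled (unconditional +
                # conditional), so the adapter features must be repeated to match.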
                if opt.scale > 1.:
                    adapter_features = [torch.cat([feats] * 2) for feats in adapter_features]

                if opt.prompt_src is not None:
                    prompts = [opt.prompt_src, opt.prompt]
                else:
                    prompts = [opt.prompt] * 2
                print("prompts: ", prompts)

                c = sd_model.get_learned_conditioning(prompts)
                if opt.scale != 1.0:
                    uc = sd_model.get_learned_conditioning([""] * len(prompts))
                else:
                    uc = None
                c, uc = fix_cond_shapes(sd_model, c, uc)

                if not hasattr(opt, 'H'):
                    opt.H = 512
                    opt.W = 512
                shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
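
                # The start code comes from one of three sources: DDIM inversion of
                # a real image, a previously saved inversion, or Gaussian noise.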
                if opt.src_img_path:
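                    # Load the source image and normalize it to [-1, 1], NCHW, at (H, W).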
                    src_img = read_image(opt.src_img_path)
                    src_img = src_img.float() / 255.
                    src_img = src_img * 2 - 1
                    if src_img.dim() == 3:
                        src_img = src_img.unsqueeze(0)
                    src_img = F.interpolate(src_img, (opt.H, opt.W))
                    src_img = src_img.to(opt.device)
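
                    # Encode to latents and invert with DDIM to recover the start code.
                    # Note: this branch assumes opt.scale != 1.0, so that uc is not None.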
                    encoder_posterior = sd_model.encode_first_stage(src_img)
                    src_x_0 = sd_model.get_first_stage_encoding(encoder_posterior)
                    start_code, latents_dict = sampler.ddim_sampling_reverse(
                        num_steps=opt.steps,
                        x_0=src_x_0,
                        conditioning=uc[:1],
                        unconditional_guidance_scale=opt.scale,
                        unconditional_conditioning=uc[:1],
                    )
                    torch.save(
                        {
                            "start_code": start_code
                        },
                        os.path.join(opt.outdir, "start_code.pth"),
                    )
                elif opt.start_code_path:
                    start_code_dict = torch.load(opt.start_code_path)
                    start_code = start_code_dict.get("start_code").to(opt.device)
                else:
                    start_code = torch.randn([1, *shape], device=opt.device)
                start_code = start_code.expand(len(prompts), -1, -1, -1)
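
                # Hook MasaCtrl into the U-Net: the registered editor makes the target
                # branch query the source branch's self-attention keys/values from
                # step STEP and layer LAYER onward.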
                editor = MutualSelfAttentionControl(STEP, LAYER)
                regiter_attention_editor_ldm(sd_model, editor)
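
                # Sample the source and target images jointly in one batch.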
                samples_latents, _ = sampler.sample(
                    S=opt.steps,
                    conditioning=c,
                    batch_size=len(prompts),
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=opt.scale,
                    unconditional_conditioning=uc,
                    x_T=start_code,
                    features_adapter=adapter_features,
                    append_to_context=append_to_context,
                    cond_tau=opt.cond_tau,
                    style_cond_tau=opt.style_cond_tau,
                )

                x_samples = sd_model.decode_first_stage(samples_latents)
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_all_result.png'), tensor2img(x_samples))

                with open(os.path.join(opt.outdir, "log.txt"), "w") as f:
                    for prom in prompts:
                        f.write(prom)
                        f.write("\n")
                    f.write(f"seed: {opt.seed}")
                for i in range(len(x_samples)):
                    base_count += 1
                    cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_result.png'), tensor2img(x_samples[i]))


if __name__ == '__main__':
    main()