# T2I-Adapter / ldm/inference_base.py
# Last commit: "support composable adapter (#5)" (b3478e4)
import argparse
import torch
from omegaconf import OmegaConf
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.adapter import Adapter, StyleAdapter, Adapter_light
from ldm.modules.extra_condition.api import ExtraCondition
from ldm.util import fix_cond_shapes, load_model_from_config, read_state_dict
# Negative prompt used when the caller does not supply one; it is the default
# value of the `--neg_prompt` command-line argument.
DEFAULT_NEGATIVE_PROMPT = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
'fewer digits, cropped, worst quality, low quality'
def get_base_argument_parser() -> argparse.ArgumentParser:
    """Create the argument parser shared by the inference scripts.

    Every option is declared once in the table below and registered in that
    order, so the generated ``--help`` output lists them in the same sequence.

    Returns:
        A fresh ``argparse.ArgumentParser`` with all base options registered.
    """
    # (flag, keyword arguments forwarded verbatim to ``add_argument``)
    option_table = (
        ('--outdir', {
            'type': str,
            'help': 'dir to write results to',
            'default': None,
        }),
        ('--prompt', {
            'type': str,
            'nargs': '?',
            'default': None,
            'help': 'positive prompt',
        }),
        ('--neg_prompt', {
            'type': str,
            'default': DEFAULT_NEGATIVE_PROMPT,
            'help': 'negative prompt',
        }),
        ('--cond_path', {
            'type': str,
            'default': None,
            'help': 'condition image path',
        }),
        ('--cond_inp_type', {
            'type': str,
            'default': 'image',
            'help': 'the type of the input condition image, take depth T2I as example, the input can be raw image, '
                    'which depth will be calculated, or the input can be a directly a depth map image',
        }),
        ('--sampler', {
            'type': str,
            'default': 'ddim',
            'choices': ['ddim', 'plms'],
            'help': 'sampling algorithm, currently, only ddim and plms are supported, more are on the way',
        }),
        ('--steps', {
            'type': int,
            'default': 50,
            'help': 'number of sampling steps',
        }),
        ('--sd_ckpt', {
            'type': str,
            'default': 'models/sd-v1-4.ckpt',
            'help': 'path to checkpoint of stable diffusion model, both .ckpt and .safetensor are supported',
        }),
        ('--vae_ckpt', {
            'type': str,
            'default': None,
            'help': 'vae checkpoint, anime SD models usually have seperate vae ckpt that need to be loaded',
        }),
        ('--adapter_ckpt', {
            'type': str,
            'default': None,
            'help': 'path to checkpoint of adapter',
        }),
        ('--config', {
            'type': str,
            'default': 'configs/stable-diffusion/sd-v1-inference.yaml',
            'help': 'path to config which constructs SD model',
        }),
        ('--max_resolution', {
            'type': float,
            'default': 512 * 512,
            'help': 'max image height * width, only for computer with limited vram',
        }),
        ('--resize_short_edge', {
            'type': int,
            'default': None,
            'help': 'resize short edge of the input image, if this arg is set, max_resolution will not be used',
        }),
        ('--C', {
            'type': int,
            'default': 4,
            'help': 'latent channels',
        }),
        ('--f', {
            'type': int,
            'default': 8,
            'help': 'downsampling factor',
        }),
        ('--scale', {
            'type': float,
            'default': 7.5,
            'help': 'unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))',
        }),
        ('--cond_tau', {
            'type': float,
            'default': 1.0,
            'help': 'timestamp parameter that determines until which step the adapter is applied, '
                    'similar as Prompt-to-Prompt tau',
        }),
        ('--cond_weight', {
            'type': float,
            'default': 1.0,
            'help': 'the adapter features are multiplied by the cond_weight. The larger the cond_weight, the more aligned '
                    'the generated image and condition will be, but the generated quality may be reduced',
        }),
        ('--seed', {
            'type': int,
            'default': 42,
        }),
        ('--n_samples', {
            'type': int,
            'default': 4,
            'help': '# of samples to generate',
        }),
    )

    base_parser = argparse.ArgumentParser()
    for flag, settings in option_table:
        base_parser.add_argument(flag, **settings)
    return base_parser
def get_sd_models(opt):
    """Build the Stable Diffusion model together with its sampler.

    Args:
        opt: options object; this function reads ``config``, ``sd_ckpt``,
            ``vae_ckpt``, ``device`` and ``sampler`` from it.

    Returns:
        A ``(sd_model, sampler)`` tuple.

    Raises:
        NotImplementedError: if ``opt.sampler`` is neither ``'plms'`` nor ``'ddim'``.
    """
    # Stable Diffusion: load from config + checkpoints, then move to the target device.
    sd_config = OmegaConf.load(f"{opt.config}")
    loaded_model = load_model_from_config(sd_config, opt.sd_ckpt, opt.vae_ckpt)
    sd_model = loaded_model.to(opt.device)

    # Sampler: constructed around the model object returned by the loader.
    if opt.sampler == 'plms':
        return sd_model, PLMSSampler(loaded_model)
    if opt.sampler == 'ddim':
        return sd_model, DDIMSampler(loaded_model)
    raise NotImplementedError
def get_t2i_adapter_models(opt):
    """Build the SD model, load adapter weights into it, and create a sampler.

    The adapter checkpoint path is taken from ``opt.<which_cond>_adapter_ckpt``
    when that attribute exists and is not ``None``; otherwise ``opt.adapter_ckpt``
    is used. Checkpoint keys that lack the ``adapter.`` prefix get it prepended
    before being loaded non-strictly into the model.

    Args:
        opt: options object; reads ``config``, ``sd_ckpt``, ``vae_ckpt``,
            ``which_cond``, ``adapter_ckpt`` (and the optional per-condition
            override), ``device`` and ``sampler``.

    Returns:
        A ``(model, sampler)`` tuple.

    Raises:
        NotImplementedError: if ``opt.sampler`` is neither ``'plms'`` nor ``'ddim'``.
    """
    sd_config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(sd_config, opt.sd_ckpt, opt.vae_ckpt)

    # Per-condition checkpoint wins over the generic one.
    adapter_ckpt_path = getattr(opt, f'{opt.which_cond}_adapter_ckpt', None)
    if adapter_ckpt_path is None:
        adapter_ckpt_path = opt.adapter_ckpt

    raw_weights = read_state_dict(adapter_ckpt_path)
    prefixed_weights = {
        (name if name.startswith('adapter.') else f'adapter.{name}'): tensor
        for name, tensor in raw_weights.items()
    }

    # Non-strict load: only report keys the model did not expect.
    _, unexpected = model.load_state_dict(prefixed_weights, strict=False)
    if len(unexpected) > 0:
        print(f"unexpected keys in loading adapter ckpt {adapter_ckpt_path}:")
        print(unexpected)

    model = model.to(opt.device)

    # sampler
    if opt.sampler == 'plms':
        return model, PLMSSampler(model)
    if opt.sampler == 'ddim':
        return model, DDIMSampler(model)
    raise NotImplementedError
def get_cond_ch(cond_type: ExtraCondition):
    """Return the channel count used for ``cond_type``.

    Sketch and canny conditions yield 1; every other condition yields 3.
    """
    single_channel_conditions = (ExtraCondition.sketch, ExtraCondition.canny)
    return 1 if cond_type in single_channel_conditions else 3
def get_adapters(opt, cond_type: ExtraCondition):
    """Build the adapter network for ``cond_type`` and load its checkpoint.

    Args:
        opt: options object. Reads ``device`` and ``cond_weight`` /
            ``adapter_ckpt``, plus the optional per-condition overrides
            ``<cond_type.name>_weight`` and ``<cond_type.name>_adapter_ckpt``
            (an override is used when it exists and is not ``None``).
        cond_type: which extra condition the adapter is for. ``style`` selects
            ``StyleAdapter``, ``color`` selects ``Adapter_light``, anything
            else selects ``Adapter``.

    Returns:
        A dict with two entries: ``'cond_weight'`` (the resolved weight) and
        ``'model'`` (the adapter module on ``opt.device`` with weights loaded).
    """
    adapter = {}

    # Per-condition weight takes precedence over the generic one.
    cond_weight = getattr(opt, f'{cond_type.name}_weight', None)
    if cond_weight is None:
        cond_weight = opt.cond_weight
    adapter['cond_weight'] = cond_weight

    if cond_type == ExtraCondition.style:
        adapter['model'] = StyleAdapter(
            width=1024,
            context_dim=768,
            num_head=8,
            n_layes=3,
            num_token=8).to(opt.device)
    elif cond_type == ExtraCondition.color:
        adapter['model'] = Adapter_light(
            cin=64 * get_cond_ch(cond_type),
            channels=[320, 640, 1280, 1280],
            nums_rb=4).to(opt.device)
    else:
        adapter['model'] = Adapter(
            cin=64 * get_cond_ch(cond_type),
            channels=[320, 640, 1280, 1280],
            nums_rb=2,
            ksize=1,
            sk=True,
            use_conv=False).to(opt.device)

    # Per-condition checkpoint takes precedence over the generic one.
    ckpt_path = getattr(opt, f'{cond_type.name}_adapter_ckpt', None)
    if ckpt_path is None:
        ckpt_path = opt.adapter_ckpt

    # Deserialize onto CPU so a checkpoint saved from CUDA tensors also loads on
    # hosts without that device; load_state_dict then copies the values into
    # the module's parameters, which already live on opt.device.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    adapter['model'].load_state_dict(state_dict)

    return adapter
def diffusion_inference(opt, model, sampler, adapter_features, append_to_context=None):
    """Sample latents with the given sampler and decode them with the model.

    Args:
        opt: options object. Reads ``prompt``, ``neg_prompt``, ``scale``,
            ``steps``, ``C``, ``f``, ``cond_tau`` and ``H`` / ``W``. If ``opt``
            has no ``H`` attribute, ``opt.H`` and ``opt.W`` are both set to 512
            (this mutates ``opt``).
        model: provides ``get_learned_conditioning`` and ``decode_first_stage``.
        sampler: provides ``sample``.
        adapter_features: forwarded to the sampler as ``features_adapter``.
        append_to_context: forwarded to the sampler unchanged.

    Returns:
        The decoded samples after applying ``(x + 1) / 2`` and clamping to
        ``[0, 1]``.
    """
    # Text embeddings; the negative prompt is only encoded when guidance scale != 1.
    cond = model.get_learned_conditioning([opt.prompt])
    uncond = model.get_learned_conditioning([opt.neg_prompt]) if opt.scale != 1.0 else None
    cond, uncond = fix_cond_shapes(model, cond, uncond)

    # Default output size when the caller did not provide one.
    if not hasattr(opt, 'H'):
        opt.H, opt.W = 512, 512

    latent_shape = [opt.C, opt.H // opt.f, opt.W // opt.f]

    latents, _ = sampler.sample(
        S=opt.steps,
        conditioning=cond,
        batch_size=1,
        shape=latent_shape,
        verbose=False,
        unconditional_guidance_scale=opt.scale,
        unconditional_conditioning=uncond,
        x_T=None,
        features_adapter=adapter_features,
        append_to_context=append_to_context,
        cond_tau=opt.cond_tau,
    )

    decoded = model.decode_first_stage(latents)
    return torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)