# Diffree / stable_diffusion / scripts / sample_diffusion.py
# Web-viewer metadata: LiruiZhao, commit 2d0e22d ("update"), 9.61 kB.
import argparse, os, sys, glob, datetime, yaml
import torch
import time
import numpy as np
from tqdm import trange
from omegaconf import OmegaConf
from PIL import Image
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
def rescale(x):
    """Return ``(x + 1) / 2``, i.e. the affine map sending -1 to 0 and 1 to 1."""
    return (x + 1.) / 2.
def custom_to_pil(x):
    """Convert a single image tensor into an RGB ``PIL.Image``.

    The tensor is detached and moved to the CPU, clamped to [-1, 1], mapped
    to [0, 1], permuted with ``(1, 2, 0)`` (presumably CHW -> HWC; confirm
    against callers), scaled by 255 and cast to ``uint8`` before being
    handed to Pillow. Non-RGB results are converted to RGB.
    """
    clamped = torch.clamp(x.detach().cpu(), -1., 1.)
    unit_range = ((clamped + 1.) / 2.).permute(1, 2, 0).numpy()
    image = Image.fromarray((255 * unit_range).astype(np.uint8))
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image
def custom_to_np(x):
    """Quantise a batch tensor to ``uint8`` with the channel axis moved last.

    Follows the ADM convention used in
    https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py

    Values are mapped with ``(x + 1) * 127.5``, clamped to [0, 255] and cast
    to ``torch.uint8``; axes are then permuted with ``(0, 2, 3, 1)``.

    Note: despite the name, the result is a contiguous CPU ``torch.Tensor``,
    not a NumPy array.
    """
    batch = x.detach().cpu()
    scaled = (batch + 1) * 127.5
    quantised = scaled.clamp(0, 255).to(torch.uint8)
    return quantised.permute(0, 2, 3, 1).contiguous()
def logs2pil(logs, keys=("sample",)):
    """Convert every entry of ``logs`` into a PIL image on a best-effort basis.

    Args:
        logs: mapping from a name to a value. Values exposing a 4-element
            ``.shape`` contribute their first item (``value[0, ...]``);
            values with a 3-element ``.shape`` are converted as-is.
        keys: unused; accepted only for backward compatibility.

    Returns:
        A dict with the same keys as ``logs``. Each value is the image
        produced by ``custom_to_pil`` or ``None`` when the entry has an
        unsupported number of dimensions or its conversion failed.
    """
    imgs = dict()
    for k in logs:
        try:
            ndim = len(logs[k].shape)
            if ndim == 4:
                img = custom_to_pil(logs[k][0, ...])
            elif ndim == 3:
                img = custom_to_pil(logs[k])
            else:
                print(f"Unknown format for key {k}. ")
                img = None
        except Exception:
            # Best effort: entries such as plain floats have no ``.shape``.
            # Only ``Exception`` is caught so that KeyboardInterrupt and
            # SystemExit still propagate.
            img = None
        imgs[k] = img
    return imgs
@torch.no_grad()
def convsample(model, shape, return_intermediates=True,
               verbose=True,
               make_prog_row=False):
    """Draw samples from ``model`` without DDIM, with gradients disabled.

    Args:
        model: object providing ``p_sample_loop`` and
            ``progressive_denoising``.
        shape: shape forwarded unchanged to the model's sampler.
        return_intermediates: forwarded to ``p_sample_loop``; not used when
            ``make_prog_row`` is true.
        verbose: forwarded to whichever sampler is called.
        make_prog_row: when true, call ``model.progressive_denoising``
            instead of ``model.p_sample_loop``.

    Returns:
        Whatever the selected model method returns.
    """
    if make_prog_row:
        # Forward the caller's ``verbose`` rather than forcing True so the
        # flag behaves the same on both sampling paths.
        return model.progressive_denoising(None, shape, verbose=verbose)
    return model.p_sample_loop(None, shape,
                               return_intermediates=return_intermediates,
                               verbose=verbose)
@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0):
    """Sample from ``model`` with a ``DDIMSampler``, with gradients disabled.

    Args:
        model: model wrapped by the sampler.
        steps: number of DDIM steps, passed as the sampler's first argument.
        shape: full sample shape; its first entry is used as the batch size
            and the remaining entries as the per-sample shape.
        eta: DDIM eta forwarded to the sampler.

    Returns:
        The ``(samples, intermediates)`` pair produced by the sampler.
    """
    sampler = DDIMSampler(model)
    batch_size, per_sample_shape = shape[0], shape[1:]
    latents, trace = sampler.sample(
        steps,
        batch_size=batch_size,
        shape=per_sample_shape,
        eta=eta,
        verbose=False,
    )
    return latents, trace
@torch.no_grad()
def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,):
    """Generate one batch of decoded samples together with timing statistics.

    The sample shape is built from ``batch_size`` and the ``in_channels`` /
    ``image_size`` attributes of ``model.model.diffusion_model``. Sampling
    runs inside ``model.ema_scope("Plotting")``.

    Args:
        model: diffusion model to sample from.
        batch_size: number of samples in the batch.
        vanilla: use ``convsample`` (with ``make_prog_row=True``) instead of
            DDIM sampling.
        custom_steps: number of DDIM steps; only used when ``vanilla`` is
            false.
        eta: DDIM eta; only used when ``vanilla`` is false.

    Returns:
        dict with ``"sample"`` (output of ``model.decode_first_stage``),
        ``"time"`` (seconds spent sampling, excluding decoding) and
        ``"throughput"`` (samples per second).
    """
    denoiser = model.model.diffusion_model
    shape = [batch_size, denoiser.in_channels, denoiser.image_size, denoiser.image_size]

    with model.ema_scope("Plotting"):
        started = time.time()
        if vanilla:
            sample, _ = convsample(model, shape, make_prog_row=True)
        else:
            sample, _ = convsample_ddim(model, steps=custom_steps, shape=shape, eta=eta)
        finished = time.time()

    elapsed = finished - started
    log = {"sample": model.decode_first_stage(sample)}
    log["time"] = elapsed
    log['throughput'] = sample.shape[0] / elapsed
    print(f'Throughput for this batch: {log["throughput"]}')
    return log
def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
    """Sample images from an unconditional model and write them to disk.

    Each image is written to ``logdir`` as a PNG by ``save_logs``. All newly
    generated batches are additionally stacked and stored as a single
    ``.npz`` file in ``nplog``.

    Args:
        model: diffusion model; must have ``cond_stage_model`` set to None.
        logdir: directory receiving the PNG files. PNGs already present
            count towards ``n_samples``.
        batch_size: maximum number of samples generated per batch.
        vanilla: use vanilla sampling instead of DDIM.
        custom_steps: number of DDIM steps (ignored when ``vanilla``).
        eta: DDIM eta (ignored when ``vanilla``).
        n_samples: target number of PNG files in ``logdir``.
        nplog: directory for the ``.npz`` dump; when None the dump is
            skipped.

    Raises:
        NotImplementedError: if the model has a conditioning stage.
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if vanilla:
        print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
    else:
        print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')

    if model.cond_stage_model is not None:
        raise NotImplementedError('Currently only sampling for unconditional models supported.')
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    tstart = time.time()
    # The next free file index equals the number of PNGs already present.
    # Starting one lower would name the first file "sample_-00001.png" and
    # leave the running count off by one.
    n_saved = len(glob.glob(os.path.join(logdir, '*.png')))
    remaining = max(n_samples - n_saved, 0)
    # Ceiling division, so a trailing partial batch is not silently dropped.
    n_batches = -(-remaining // batch_size)

    all_images = []
    print(f"Running unconditional sampling for {n_samples} samples")
    for _ in trange(n_batches, desc="Sampling Batches (unconditional)"):
        # Shrink the final batch so exactly ``n_samples`` images are produced.
        current_batch_size = min(batch_size, n_samples - n_saved)
        logs = make_convolutional_sample(model, batch_size=current_batch_size,
                                         vanilla=vanilla, custom_steps=custom_steps,
                                         eta=eta)
        n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
        all_images.append(custom_to_np(logs["sample"]))
        if n_saved >= n_samples:
            print(f'Finish after generating {n_saved} samples')
            break

    # ``np.concatenate`` rejects an empty list, and without a target
    # directory there is nowhere to write the archive.
    if all_images and nplog is not None:
        all_img = np.concatenate(all_images, axis=0)
        shape_str = "x".join([str(x) for x in all_img.shape])
        nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
        np.savez(nppath, all_img)

    print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
    """Persist the batch stored under ``key`` and return the updated counter.

    Args:
        logs: mapping that may contain ``key``; nothing happens if it does
            not.
        path: directory for PNG output (used when ``np_path`` is None).
        n_saved: running count of items saved so far; also used in the
            output file names.
        key: entry of ``logs`` to save; also the PNG file-name prefix.
        np_path: when given, the whole batch is written there as one
            ``.npz`` file instead of individual PNGs.

    Returns:
        ``n_saved`` increased by the number of items written.
    """
    if key not in logs:
        return n_saved

    batch = logs[key]
    if np_path is not None:
        npbatch = custom_to_np(batch)
        dims = "x".join(str(d) for d in npbatch.shape)
        np.savez(os.path.join(np_path, f"{n_saved}-{dims}-samples.npz"), npbatch)
        return n_saved + npbatch.shape[0]

    for item in batch:
        target = os.path.join(path, f"{key}_{n_saved:06}.png")
        custom_to_pil(item).save(target)
        n_saved += 1
    return n_saved
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-r",
"--resume",
type=str,
nargs="?",
help="load from logdir or checkpoint in logdir",
)
parser.add_argument(
"-n",
"--n_samples",
type=int,
nargs="?",
help="number of samples to draw",
default=50000
)
parser.add_argument(
"-e",
"--eta",
type=float,
nargs="?",
help="eta for ddim sampling (0.0 yields deterministic sampling)",
default=1.0
)
parser.add_argument(
"-v",
"--vanilla_sample",
default=False,
action='store_true',
help="vanilla sampling (default option is DDIM sampling)?",
)
parser.add_argument(
"-l",
"--logdir",
type=str,
nargs="?",
help="extra logdir",
default="none"
)
parser.add_argument(
"-c",
"--custom_steps",
type=int,
nargs="?",
help="number of steps for ddim and fastdpm sampling",
default=50
)
parser.add_argument(
"--batch_size",
type=int,
nargs="?",
help="the bs",
default=10
)
return parser
def load_model_from_config(config, sd):
    """Instantiate a model from ``config`` and prepare it for inference.

    Args:
        config: model configuration passed to ``instantiate_from_config``.
        sd: state dict to load non-strictly, or None to keep the freshly
            initialised weights (``load_model`` passes None when it is
            called without a checkpoint).

    Returns:
        The model, moved to CUDA and switched to eval mode.
    """
    model = instantiate_from_config(config)
    if sd is not None:
        model.load_state_dict(sd, strict=False)
    else:
        # ``load_state_dict(None)`` would raise; make the fallback visible
        # instead of failing or silently sampling from random weights.
        print("No state dict given; using freshly initialised weights.")
    model.cuda()
    model.eval()
    return model
def load_model(config, ckpt, gpu, eval_mode):
    """Build the model described by ``config.model``, optionally from a checkpoint.

    Args:
        config: configuration object with a ``model`` attribute.
        ckpt: checkpoint path; when falsy no checkpoint is read.
        gpu: unused by this function.
        eval_mode: unused by this function.

    Returns:
        ``(model, global_step)`` where ``global_step`` is read from the
        checkpoint's ``"global_step"`` entry, or None without a checkpoint.
    """
    # Defaults for the no-checkpoint case.
    pl_sd = {"state_dict": None}
    global_step = None

    if ckpt:
        print(f"Loading model from {ckpt}")
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]

    model = load_model_from_config(config.model, pl_sd["state_dict"])
    return model, global_step
if __name__ == "__main__":
    # Script entry point: resolve the checkpoint and its config, load the
    # model, create a fresh timestamped output directory and run sampling.
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    sys.path.append(os.getcwd())

    parser = get_parser()
    # Unrecognised arguments are kept and later parsed as config overrides.
    opt, unknown = parser.parse_known_args()

    # Without this check a missing -r fails inside os.path.exists(None)
    # with an unhelpful TypeError.
    if opt.resume is None:
        parser.error("-r/--resume is required (a log directory or a checkpoint file)")
    if not os.path.exists(opt.resume):
        raise ValueError("Cannot find {}".format(opt.resume))

    if os.path.isfile(opt.resume):
        # A checkpoint file was given; its directory is the log directory.
        logdir = os.path.dirname(opt.resume)
        print(f'Logdir is {logdir}')
        ckpt = opt.resume
    else:
        if not os.path.isdir(opt.resume):
            raise ValueError(f"{opt.resume} is not a directory")
        logdir = opt.resume.rstrip("/")
        ckpt = os.path.join(logdir, "model.ckpt")

    # The config is read from the original log directory, before any
    # --logdir redirection below.
    base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
    opt.base = base_configs

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)
    if "model" not in config:
        raise ValueError(
            f"No model configuration found: expected {os.path.join(logdir, 'config.yaml')} "
            "or 'model.*' overrides on the command line"
        )

    gpu = True
    eval_mode = True

    if opt.logdir != "none":
        # Keep the run's directory name but place it under the extra logdir.
        locallog = os.path.basename(logdir.rstrip(os.sep))
        print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
        logdir = os.path.join(opt.logdir, locallog)

    print(config)

    model, global_step = load_model(config, ckpt, gpu, eval_mode)
    print(f"global step: {global_step}")
    print(75 * "=")
    print("logging to:")
    # A None global step cannot be formatted with ":08".
    step_tag = f"{global_step:08}" if global_step is not None else "unknown"
    logdir = os.path.join(logdir, "samples", step_tag, now)
    imglogdir = os.path.join(logdir, "img")
    numpylogdir = os.path.join(logdir, "numpy")

    os.makedirs(imglogdir)
    os.makedirs(numpylogdir)
    print(logdir)
    print(75 * "=")

    # write config out
    sampling_file = os.path.join(logdir, "sampling_config.yaml")
    sampling_conf = vars(opt)

    with open(sampling_file, 'w') as f:
        yaml.dump(sampling_conf, f, default_flow_style=False)
    print(sampling_conf)

    run(model, imglogdir, eta=opt.eta,
        vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
        batch_size=opt.batch_size, nplog=numpylogdir)

    print("done.")