Spaces:
Runtime error
Runtime error
from pathlib import Path | |
import numpy as np | |
import torch | |
from misc import torch_samps_to_imgs | |
from adapt import Karras, ScoreAdapter, power_schedule | |
from adapt_gddpm import GuidedDDPM | |
from adapt_ncsn import NCSN as _NCSN | |
# from adapt_vesde import VESDE # not included to prevent import conflicts | |
from adapt_sd import StableDiffusion | |
from my.utils import tqdm, EventStorage, HeartBeat, EarlyLoopBreak | |
from my.config import BaseConf, dispatch | |
from my.utils.seed import seed_everything | |
class GDDPM(BaseConf): | |
"""Guided DDPM from OpenAI""" | |
model: str = "m_lsun_256" | |
lsun_cat: str = "bedroom" | |
imgnet_cat: int = -1 | |
def make(self): | |
args = self.dict() | |
model = GuidedDDPM(**args) | |
return model | |
class SD(BaseConf): | |
"""Stable Diffusion""" | |
variant: str = "v1" | |
v2_highres: bool = False | |
prompt: str = "a photograph of an astronaut riding a horse" | |
scale: float = 3.0 # classifier free guidance scale | |
precision: str = 'autocast' | |
def make(self): | |
args = self.dict() | |
model = StableDiffusion(**args) | |
return model | |
class SDE(BaseConf): | |
def make(self): | |
args = self.dict() | |
model = VESDE(**args) | |
return model | |
class NCSN(BaseConf): | |
def make(self): | |
args = self.dict() | |
model = _NCSN(**args) | |
return model | |
class KarrasGen(BaseConf): | |
family: str = "gddpm" | |
gddpm: GDDPM = GDDPM() | |
sd: SD = SD() | |
# sde: SDE = SDE() | |
ncsn: NCSN = NCSN() | |
batch_size: int = 10 | |
num_images: int = 1250 | |
num_t: int = 40 | |
σ_max: float = 80.0 | |
heun: bool = True | |
langevin: bool = False | |
cls_scaling: float = 1.0 # classifier guidance scaling | |
def run(self): | |
args = self.dict() | |
family = args.pop("family") | |
model = getattr(self, family).make() | |
self.karras_generate(model, **args) | |
def karras_generate( | |
model: ScoreAdapter, | |
batch_size, num_images, σ_max, num_t, langevin, heun, cls_scaling, | |
**kwargs | |
): | |
del kwargs # removed extra args | |
num_batches = num_images // batch_size | |
fuse = EarlyLoopBreak(5) | |
with tqdm(total=num_batches) as pbar, \ | |
HeartBeat(pbar) as hbeat, \ | |
EventStorage() as metric: | |
all_imgs = [] | |
for _ in range(num_batches): | |
if fuse.on_break(): | |
break | |
pipeline = Karras.inference( | |
model, batch_size, num_t, | |
init_xs=None, heun=heun, σ_max=σ_max, | |
langevin=langevin, cls_scaling=cls_scaling | |
) | |
for imgs in tqdm(pipeline, total=num_t+1, disable=False): | |
# _std = imgs.std().item() | |
# print(_std) | |
hbeat.beat() | |
pass | |
if isinstance(model, StableDiffusion): | |
imgs = model.decode(imgs) | |
imgs = torch_samps_to_imgs(imgs, uncenter=model.samps_centered()) | |
all_imgs.append(imgs) | |
pbar.update() | |
all_imgs = np.concatenate(all_imgs, axis=0) | |
metric.put_artifact("imgs", ".npy", lambda fn: np.save(fn, all_imgs)) | |
metric.step() | |
hbeat.done() | |
class SMLDGen(BaseConf): | |
family: str = "ncsn" | |
gddpm: GDDPM = GDDPM() | |
# sde: SDE = SDE() | |
ncsn: NCSN = NCSN() | |
batch_size: int = 16 | |
num_images: int = 16 | |
num_stages: int = 80 | |
num_steps: int = 15 | |
σ_max: float = 80.0 | |
ε: float = 1e-5 | |
def run(self): | |
args = self.dict() | |
family = args.pop("family") | |
model = getattr(self, family).make() | |
self.smld_generate(model, **args) | |
def smld_generate( | |
model: ScoreAdapter, | |
batch_size, num_images, num_stages, num_steps, σ_max, ε, | |
**kwargs | |
): | |
num_batches = num_images // batch_size | |
σs = power_schedule(σ_max, model.σ_min, num_stages) | |
σs = [model.snap_t_to_nearest_tick(σ)[0] for σ in σs] | |
fuse = EarlyLoopBreak(5) | |
with tqdm(total=num_batches) as pbar, \ | |
HeartBeat(pbar) as hbeat, \ | |
EventStorage() as metric: | |
all_imgs = [] | |
for _ in range(num_batches): | |
if fuse.on_break(): | |
break | |
init_xs = torch.rand(batch_size, *model.data_shape(), device=model.device) | |
if model.samps_centered(): | |
init_xs = init_xs * 2 - 1 # [0, 1] -> [-1, 1] | |
pipeline = smld_inference( | |
model, σs, num_steps, ε, init_xs | |
) | |
for imgs in tqdm(pipeline, total=(num_stages * num_steps)+1, disable=False): | |
pbar.set_description(f"{imgs.max().item():.3f}") | |
metric.put_scalars( | |
max=imgs.max().item(), min=imgs.min().item(), std=imgs.std().item() | |
) | |
metric.step() | |
hbeat.beat() | |
pbar.update() | |
imgs = torch_samps_to_imgs(imgs, uncenter=model.samps_centered()) | |
all_imgs.append(imgs) | |
all_imgs = np.concatenate(all_imgs, axis=0) | |
metric.put_artifact("imgs", ".npy", lambda fn: np.save(fn, all_imgs)) | |
metric.step() | |
hbeat.done() | |
def smld_inference(model, σs, num_steps, ε, init_xs): | |
from math import sqrt | |
# not doing conditioning or cls guidance; for gddpm only lsun works; fine. | |
xs = init_xs | |
yield xs | |
for i in range(len(σs)): | |
α_i = ε * ((σs[i] / σs[-1]) ** 2) | |
for _ in range(num_steps): | |
grad = model.score(xs, σs[i]) | |
z = torch.randn_like(xs) | |
xs = xs + α_i * grad + sqrt(2 * α_i) * z | |
yield xs | |
def load_np_imgs(fname): | |
fname = Path(fname) | |
data = np.load(fname) | |
if fname.suffix == ".npz": | |
imgs = data['arr_0'] | |
else: | |
imgs = data | |
return imgs | |
def visualize(max_n_imgs=16): | |
import torchvision.utils as vutils | |
from imageio import imwrite | |
from einops import rearrange | |
all_imgs = load_np_imgs("imgs/step_0.npy") | |
imgs = all_imgs[:max_n_imgs] | |
imgs = rearrange(imgs, "N H W C -> N C H W", C=3) | |
imgs = torch.from_numpy(imgs) | |
pane = vutils.make_grid(imgs, padding=2, nrow=4) | |
pane = rearrange(pane, "C H W -> H W C", C=3) | |
pane = pane.numpy() | |
imwrite("preview.jpg", pane) | |
if __name__ == "__main__": | |
seed_everything(0) | |
dispatch(KarrasGen) | |
visualize(16) | |