# Spaces: Sleeping (page-status banner captured along with this file; not part of the program)
import os | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
from tqdm import tqdm | |
from einops import rearrange, repeat | |
from omegaconf import OmegaConf | |
from diffusers import DDIMScheduler, StableDiffusionPipeline, DiffusionPipeline | |
from torchvision.utils import save_image | |
from torchvision.io import read_image | |
from pytorch_lightning import seed_everything | |
from masactrl.masactrl_processor import register_attention_processor | |
# Module-level setup: choose the compute device, then build the SDXL pipeline
# with a DDIM scheduler. Everything below runs once, at import time.

# Decide the device *before* touching any CUDA-only API, so that the CPU
# fallback is actually reachable on machines without a GPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if device.type == "cuda":
    torch.cuda.set_device(0)  # set the GPU device

# Note that you may add your Hugging Face token to get access to the models

# Half precision is kept on the GPU; on CPU fall back to float32, which every
# CPU operator supports.
weight_dtype = torch.float16 if device.type == "cuda" else torch.float32
model_path = "stabilityai/stable-diffusion-xl-base-1.0"

# DDIM scheduler that replaces the pipeline's default scheduler.
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

# Load the pretrained pipeline with the scheduler above and move it to the
# selected device.
pipe = DiffusionPipeline.from_pretrained(
    model_path,
    scheduler=scheduler,
    torch_dtype=weight_dtype,
).to(device)
def consistent_synthesis():
    """Render a two-prompt pair with and without the MasaCtrl processor and save the results.

    The first prompt is the source and the second is the edited target. Both
    are generated once with the unmodified pipeline, then once more per entry
    in the layer list after ``register_attention_processor`` has been applied
    to ``pipe.unet``.

    For every layer entry a new ``sample_<n>`` directory is created under the
    experiment folder, where ``n`` is the number of entries already present
    there. Three PNGs (source, target without MasaCtrl, target with MasaCtrl)
    and a ``prompts.txt`` listing the prompts and the seed are written into it.

    Relies on the module-level ``pipe``, ``device`` and ``weight_dtype``.
    """
    seed = 42
    seed_everything(seed)

    base_dir = "./workdir/masactrl_exp/oldman_smiling"
    os.makedirs(base_dir, exist_ok=True)

    prompts = [
        "A portrait of an old man, facing camera, best quality",
        "A portrait of an old man, facing camera, smiling, best quality",
    ]

    # Inference settings for the MasaCtrl runs.
    # TODO: note that the hyperparameters of MasaCtrl for SDXL may not be optimal
    start_step = 4
    start_layers = [44, 54, 64]  # one MasaCtrl run per layer setting
    model_type = "SDXL"

    # One noise map, expanded along the batch axis so every prompt starts from
    # the same latent. (128x128 presumably matches SDXL's default output
    # resolution -- confirm against the pipeline.)
    start_code = torch.randn(
        [1, 4, 128, 128], dtype=weight_dtype, device=device
    ).expand(len(prompts), -1, -1, -1)

    # Reference images produced by the pipeline before any processor is registered.
    baseline_images = pipe(prompts, latents=start_code, guidance_scale=7.5).images

    for start_layer in start_layers:
        # NOTE(review): this dict is built but never handed to
        # register_attention_processor below, so every iteration registers the
        # processor with identical arguments and only the output file names
        # differ per layer. Presumably it was meant to be forwarded -- confirm
        # against the helper's signature before wiring it in.
        processor_args = {
            "start_step": start_step,
            "start_layer": start_layer,
            "model_type": model_type,
        }
        # Hijack the UNet's attention modules with the MasaCtrl processor.
        register_attention_processor(pipe.unet, processor_type="MasaCtrlProcessor")

        # Same prompts and latents, now through the patched UNet.
        edited_images = pipe(prompts, latents=start_code, guidance_scale=7.5).images

        # Name the run directory after the number of entries already present.
        run_index = len(os.listdir(base_dir))
        run_dir = os.path.join(base_dir, f"sample_{run_index}")
        os.makedirs(run_dir, exist_ok=True)

        tag = f"step{start_step}_layer{start_layer}"
        for prefix, images, index in (
            ("source", baseline_images, 0),
            ("without", baseline_images, 1),
            ("masactrl", edited_images, -1),
        ):
            images[index].save(os.path.join(run_dir, f"{prefix}_{tag}.png"))

        # Record the prompts and the seed next to the images.
        with open(os.path.join(run_dir, "prompts.txt"), "w") as handle:
            handle.writelines(f"{prompt}\n" for prompt in prompts)
            handle.write(f"seed: {seed}\n")

        print("Syntheiszed images are saved in", run_dir)
# Script entry point: run the synthesis demo only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    consistent_synthesis()