# HF Spaces badge (scrape artifact): Spaces — Running on Zero
import argparse

import torch
from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel

from mvadapter.pipelines.pipeline_mvadapter_t2mv_sdxl import MVAdapterT2MVSDXLPipeline
from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
from mvadapter.utils import (
    get_orthogonal_camera,
    get_plucker_embeds_from_cameras_ortho,
    make_image_grid,
)
def prepare_pipeline(
    base_model,
    vae_model,
    unet_model,
    lora_model,
    adapter_path,
    scheduler,
    num_views,
    device,
    dtype,
):
    """Build and configure the MV-Adapter text-to-multiview SDXL pipeline.

    Args:
        base_model: HF repo id or local path of the SDXL base model.
        vae_model: optional VAE repo id/path; ``None`` keeps the base VAE.
        unet_model: optional UNet repo id/path; ``None`` keeps the base UNet.
        lora_model: optional LoRA spec ``"<repo_or_dir>/<weight_file>"``;
            ``None`` skips LoRA loading.
        adapter_path: repo id/path holding the MV-Adapter weights.
        scheduler: ``"ddpm"``, ``"lcm"``, or anything else to keep the
            pipeline's default scheduler class under the SNR shift wrapper.
        num_views: number of views the custom adapter is initialized for.
        device: torch device string (e.g. ``"cuda"``).
        dtype: torch dtype the pipeline is cast to.

    Returns:
        A fully configured ``MVAdapterT2MVSDXLPipeline`` on ``device``/``dtype``.

    Raises:
        ValueError: if ``lora_model`` does not contain a ``/`` separator.
    """
    # Load VAE / UNet overrides only when explicitly provided.
    pipe_kwargs = {}
    if vae_model is not None:
        pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)
    if unet_model is not None:
        pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)

    pipe: MVAdapterT2MVSDXLPipeline
    pipe = MVAdapterT2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs)

    # Wrap the scheduler with the SNR-shift scheduler; swap the underlying
    # scheduler class if one of the known names was requested.
    scheduler_class = None
    if scheduler == "ddpm":
        scheduler_class = DDPMScheduler
    elif scheduler == "lcm":
        scheduler_class = LCMScheduler
    pipe.scheduler = ShiftSNRScheduler.from_scheduler(
        pipe.scheduler,
        shift_mode="interpolated",
        shift_scale=8.0,
        scheduler_class=scheduler_class,
    )
    pipe.init_custom_adapter(num_views=num_views)
    pipe.load_custom_adapter(
        adapter_path, weight_name="mvadapter_t2mv_sdxl.safetensors"
    )

    pipe.to(device=device, dtype=dtype)
    # The conditioning encoder is not moved by pipe.to(); cast it explicitly.
    pipe.cond_encoder.to(device=device, dtype=dtype)

    # Load LoRA if provided. The spec must contain a "/" separating the
    # repo/directory from the weight file name; previously a bare file name
    # crashed with an opaque unpacking ValueError — fail with a clear message
    # instead (same exception type, so callers catching ValueError still work).
    if lora_model is not None:
        if "/" not in lora_model:
            raise ValueError(
                f"lora_model must look like '<repo_or_dir>/<weight_file>', got {lora_model!r}"
            )
        model_, name_ = lora_model.rsplit("/", 1)
        pipe.load_lora_weights(model_, weight_name=name_)

    return pipe
def run_pipeline(
    pipe,
    num_views,
    text,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    negative_prompt,
    lora_scale=1.0,
    device="cuda",
):
    """Generate ``num_views`` images for ``text`` with camera conditioning.

    Args:
        pipe: a pipeline prepared by :func:`prepare_pipeline`.
        num_views: number of views to generate per prompt.
        text: the text prompt.
        height: output image height in pixels.
        width: output image width in pixels.
        num_inference_steps: diffusion denoising steps.
        guidance_scale: classifier-free guidance strength.
        seed: RNG seed; ``-1`` leaves the generator unseeded (non-deterministic).
        negative_prompt: negative text prompt.
        lora_scale: LoRA strength passed via ``cross_attention_kwargs``.
        device: torch device string for cameras and the generator.

    Returns:
        The list of generated PIL images.
    """
    # Orthographic camera preset. Elevation is flat for every view (was a
    # hard-coded 6-element list; now scales with num_views).
    # NOTE(review): the azimuth preset has exactly 6 entries, so this function
    # effectively assumes num_views == 6 — confirm before changing num_views.
    cameras = get_orthogonal_camera(
        elevation_deg=[0] * num_views,
        distance=[1.8] * num_views,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 270, 315]],
        device=device,
    )
    # Plücker ray embeddings, remapped from [-1, 1] to [0, 1] for the
    # control-image input.
    plucker_embeds = get_plucker_embeds_from_cameras_ortho(
        cameras.c2w, [1.1] * num_views, width
    )
    control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1)

    # Seed only when requested; -1 keeps the run non-deterministic.
    pipe_kwargs = {}
    if seed != -1:
        pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)

    images = pipe(
        text,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_views,
        control_image=control_images,
        control_conditioning_scale=1.0,
        negative_prompt=negative_prompt,
        cross_attention_kwargs={"scale": lora_scale},
        **pipe_kwargs,
    ).images
    return images
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
# Models | |
parser.add_argument( | |
"--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0" | |
) | |
parser.add_argument( | |
"--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix" | |
) | |
parser.add_argument("--unet_model", type=str, default=None) | |
parser.add_argument("--scheduler", type=str, default=None) | |
parser.add_argument("--lora_model", type=str, default=None) | |
parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter") | |
parser.add_argument("--num_views", type=int, default=6) | |
# Device | |
parser.add_argument("--device", type=str, default="cuda") | |
# Inference | |
parser.add_argument("--text", type=str, required=True) | |
parser.add_argument("--num_inference_steps", type=int, default=50) | |
parser.add_argument("--guidance_scale", type=float, default=7.0) | |
parser.add_argument("--seed", type=int, default=-1) | |
parser.add_argument( | |
"--negative_prompt", | |
type=str, | |
default="watermark, ugly, deformed, noisy, blurry, low contrast", | |
) | |
parser.add_argument("--lora_scale", type=float, default=1.0) | |
parser.add_argument("--output", type=str, default="output.png") | |
args = parser.parse_args() | |
pipe = prepare_pipeline( | |
base_model=args.base_model, | |
vae_model=args.vae_model, | |
unet_model=args.unet_model, | |
lora_model=args.lora_model, | |
adapter_path=args.adapter_path, | |
scheduler=args.scheduler, | |
num_views=args.num_views, | |
device=args.device, | |
dtype=torch.float16, | |
) | |
images = run_pipeline( | |
pipe, | |
num_views=args.num_views, | |
text=args.text, | |
height=768, | |
width=768, | |
num_inference_steps=args.num_inference_steps, | |
guidance_scale=args.guidance_scale, | |
seed=args.seed, | |
negative_prompt=args.negative_prompt, | |
lora_scale=args.lora_scale, | |
device=args.device, | |
) | |
make_image_grid(images, rows=1).save(args.output) | |