from diffusers import DDIMScheduler

import torchvision.transforms.functional as TF

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
sys.path.append('./')

from zero123 import Zero123Pipeline

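# Minimal intended usage (a sketch; assumes a CUDA device and an RGB image
# tensor of shape [B, 3, H, W] with values in [0, 1]):
#
#   guidance = Zero123(device)
#   guidance.get_img_embeds(image)   # cache reference-view embeddings
#   loss = guidance.train_step(image, polar, azimuth, radius)  # per-view deltas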
class Zero123(nn.Module):
    def __init__(self, device, fp16=True, t_range=[0.02, 0.98]):
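        """Score-distillation guidance around the Zero123 (zero123-xl) novel-view
        diffusion model.

        t_range restricts sampled diffusion timesteps to the given fraction of
        the schedule, the usual stabilisation for score distillation.
        """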
        super().__init__()

        self.device = device
        self.fp16 = fp16
        self.dtype = torch.float16 if fp16 else torch.float32
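
        # Zero123 is distributed as a custom diffusers pipeline; the fp16_ema
        # variant selects the half-precision EMA weights.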
        self.pipe = Zero123Pipeline.from_pretrained(
            "bennyguo/zero123-xl-diffusers",
            variant="fp16_ema" if self.fp16 else None,
            torch_dtype=self.dtype,
        ).to(self.device)
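
        # Guidance is inference-only: keep every sub-module in eval mode.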
        self.pipe.image_encoder.eval()
        self.pipe.vae.eval()
        self.pipe.unet.eval()
        self.pipe.clip_camera_projection.eval()

        self.vae = self.pipe.vae
        self.unet = self.pipe.unet

        self.pipe.set_progress_bar_config(disable=True)
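
        # A fresh DDIM scheduler, decoupled from the pipeline's own, drives both
        # refine() sampling and train_step() noise injection.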
        self.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.num_train_timesteps = self.scheduler.config.num_train_timesteps

        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)

        self.embeddings = None  # set by get_img_embeds(): [CLIP image embedding, VAE latent]

    @torch.no_grad()
    def get_img_embeds(self, x):
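        """Cache the reference-view conditioning for x, an RGB tensor in [0, 1]:
        its CLIP image embedding and its (unscaled) VAE latent."""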
        x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False)
        x_pil = [TF.to_pil_image(image) for image in x]
        x_clip = self.pipe.feature_extractor(
            images=x_pil, return_tensors="pt"
        ).pixel_values.to(device=self.device, dtype=self.dtype)
        c = self.pipe.image_encoder(x_clip).image_embeds
        # Zero123 concatenates the *unscaled* latent to the UNet input,
        # so undo the scaling applied by encode_imgs.
        v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor
        self.embeddings = [c, v]

    @torch.no_grad()
    def refine(self, pred_rgb, polar, azimuth, radius,
               guidance_scale=5, steps=50, strength=0.8,
               ):
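        """DDIM-sample a novel view at the given (polar, azimuth, radius) offsets.

        strength is the fraction of denoising steps skipped: larger values keep
        more of pred_rgb, while strength == 0 ignores it and samples from pure noise.
        """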
        batch_size = pred_rgb.shape[0]

        self.scheduler.set_timesteps(steps)

        if strength == 0:
            init_step = 0
            # pure generation: one random 32x32 latent per batch element (256 px / 8)
            latents = torch.randn((batch_size, 4, 32, 32), device=self.device, dtype=self.dtype)
        else:
            init_step = int(steps * strength)
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))
            latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])

        # Zero123 camera conditioning: [d_polar (rad), sin(d_azimuth), cos(d_azimuth), d_radius],
        # appended to the CLIP image embedding and projected to the UNet's context dim.
        T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1)
        T = torch.from_numpy(T).unsqueeze(1).to(self.dtype).to(self.device)
        cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
        cc_emb = self.pipe.clip_camera_projection(cc_emb)
        cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

        vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1)
        vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)
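
        # Double the batch for classifier-free guidance: conditional first,
        # unconditional (zeroed conditioning) second.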
        for t in self.scheduler.timesteps[init_step:]:

            x_in = torch.cat([latents] * 2)
            t_in = torch.cat([t.view(1)] * 2).to(self.device)

            noise_pred = self.unet(
                torch.cat([x_in, vae_emb], dim=1),
                t_in.to(self.unet.dtype),
                encoder_hidden_states=cc_emb,
            ).sample

            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        imgs = self.decode_latents(latents)
        return imgs

    def train_step(self, pred_rgb, polar, azimuth, radius, step_ratio=None,
                   guidance_scale=5, as_latent=False):
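        """One score-distillation (SDS) step.

        Returns a loss whose gradient w.r.t. pred_rgb equals w(t) * (noise_pred - noise),
        without backpropagating through the UNet. pred_rgb is [B, 3, H, W] in [0, 1];
        with as_latent=True it is treated directly as a latent.
        """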
        batch_size = pred_rgb.shape[0]

        if as_latent:
            latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
        else:
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))

        if step_ratio is not None:
            # annealed timestep: as training progresses (step_ratio -> 1), sample less noisy t
            t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
            t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
        else:
            t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

        # SDS weighting w(t) = 1 - alpha_bar_t
        w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)

            x_in = torch.cat([latents_noisy] * 2)
            t_in = torch.cat([t] * 2)

            T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1)
            T = torch.from_numpy(T).unsqueeze(1).to(self.dtype).to(self.device)
            cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
            cc_emb = self.pipe.clip_camera_projection(cc_emb)
            cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

            vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1)
            vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)
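
            # The UNet consumes the noisy latent concatenated channel-wise with
            # the reference-view latent (4 + 4 = 8 input channels).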
            noise_pred = self.unet(
                torch.cat([x_in, vae_emb], dim=1),
                t_in.to(self.unet.dtype),
                encoder_hidden_states=cc_emb,
            ).sample

            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        # Reparameterised SDS: d(loss)/d(latents) == grad, with the UNet kept out of the graph.
        target = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum')

        return loss

    def decode_latents(self, latents):
        # undo VAE scaling, decode, and map from [-1, 1] back to [0, 1]
        latents = 1 / self.vae.config.scaling_factor * latents

        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def encode_imgs(self, imgs, mode=False):
        # imgs: [B, 3, H, W] in [0, 1] -> scaled latents; mode=True takes the
        # posterior mode instead of sampling
        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.mode() if mode else posterior.sample()
        latents = latents * self.vae.config.scaling_factor

        return latents


if __name__ == '__main__':
    import cv2
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument('input', type=str)
    parser.add_argument('--polar', type=float, default=0, help='delta polar angle in [-90, 90]')
    parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]')
    parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]')

    opt = parser.parse_args()

    device = torch.device('cuda')

    print(f'[INFO] loading image from {opt.input} ...')
    image = cv2.imread(opt.input, cv2.IMREAD_UNCHANGED)
    # IMREAD_UNCHANGED may yield a 4-channel BGRA image; drop alpha before converting
    if image.shape[-1] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
    else:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
    image = image.astype(np.float32) / 255.0
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)

    print('[INFO] loading model ...')
    zero123 = Zero123(device)

    print('[INFO] running model ...')
    zero123.get_img_embeds(image)

    # sampling is stochastic: each loop iteration draws a fresh novel view
    while True:
        outputs = zero123.refine(image, polar=[opt.polar], azimuth=[opt.azimuth], radius=[opt.radius], strength=0)
        plt.imshow(outputs.float().cpu().numpy().transpose(0, 2, 3, 1)[0])
        plt.show()