# marigold-dc / marigold_dc.py
# Committed by: toshas
# Initial commit 1619d3a
import logging
import warnings
import diffusers
import numpy as np
import torch
from diffusers import MarigoldDepthPipeline
# Import-time setup: hide FutureWarning messages globally and switch off the
# progress bars controlled by diffusers' logging utilities.
warnings.simplefilter("ignore", FutureWarning)
diffusers.utils.logging.disable_progress_bar()
class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
    """Depth completion built on top of ``MarigoldDepthPipeline``.

    At every denoising step the current depth latent and a global affine
    mapping (scale/shift) are optimized with Adam, so that the decoded depth
    agrees with the known values of a sparse depth map.
    """

    def __call__(
        self,
        image,
        sparse_depth,
        num_inference_steps=50,
        processing_resolution=0,
        seed=2024,
        dry_run=False,
    ):
        """Run sparse-depth-guided denoising, yielding intermediate estimates.

        This method is a generator that yields one ``(depth, rmse)`` pair per
        scheduler timestep.

        Args:
            image: Input image passed to ``self.image_processor.preprocess``.
                In dry-run mode it is only converted with ``np.array`` and
                indexed as ``[:, :, 0]``.
            sparse_depth: 2D ``np.ndarray`` whose shape must equal the image's
                original resolution. Entries ``> 0`` are known depth samples.
                All other entries (zeros) are treated as missing.
            num_inference_steps: Number of scheduler timesteps (must be >= 1).
            processing_resolution: Forwarded unchanged to
                ``self.image_processor.preprocess``. The default of 0 presumably
                means "no resizing" (TODO confirm against diffusers).
            seed: Seed of the ``torch.Generator`` handed to the scheduler.
            dry_run: If True, no model is run. The first image channel (as a
                float array) is yielded ``num_inference_steps`` times, together
                with a dummy value ``log(i + 1)``.

        Yields:
            ``(depth, rmse)``:
            - ``depth`` is the current metric depth estimate. It is a detached
              tensor resized to the original image resolution.
            - ``rmse`` is a float: the RMSE against the valid sparse samples.

        Raises:
            ValueError: If ``num_inference_steps`` is None or < 1, if
                ``sparse_depth`` is not a 2D ndarray, has no positive values,
                or does not match the image resolution.
        """
        # Resolving variables
        device = self._execution_device
        generator = torch.Generator(device=device).manual_seed(seed)

        if dry_run:
            logging.warning("Dry run mode")
            for i in range(num_inference_steps):
                yield np.array(image)[:, :, 0].astype(float), float(np.log(i + 1))
            return

        # Check inputs (cheap checks first, before any model is run).
        if num_inference_steps is None or num_inference_steps < 1:
            raise ValueError("Invalid num_inference_steps")
        if not isinstance(sparse_depth, np.ndarray) or sparse_depth.ndim != 2:
            raise ValueError(
                "Sparse depth should be a 2D numpy ndarray with zeros at missing positions"
            )
        # The affine mapping below is anchored to the min/max of the valid
        # (positive) samples. Without any such sample those reductions would
        # fail on an empty tensor with an opaque error, after the expensive
        # encoding work has already run.
        if not np.any(sparse_depth > 0):
            raise ValueError(
                "Sparse depth should contain at least one positive (valid) value"
            )

        try:
            with torch.no_grad():
                # Prepare empty text conditioning (computed once, then cached
                # on the pipeline instance).
                if self.empty_text_embedding is None:
                    prompt = ""
                    text_inputs = self.tokenizer(
                        prompt,
                        padding="do_not_pad",
                        max_length=self.tokenizer.model_max_length,
                        truncation=True,
                        return_tensors="pt",
                    )
                    text_input_ids = text_inputs.input_ids.to(device)
                    self.empty_text_embedding = self.text_encoder(text_input_ids)[
                        0
                    ]  # [1,2,1024]

                # Preprocess input images
                image, padding, original_resolution = self.image_processor.preprocess(
                    image,
                    processing_resolution=processing_resolution,
                    device=device,
                    dtype=self.dtype,
                )  # [N,3,PPH,PPW]

            # `image` is now padded (and possibly resized), so report the
            # original resolution rather than the tensor's spatial shape.
            original_hw = tuple(original_resolution)
            if sparse_depth.shape != original_hw:
                raise ValueError(
                    f"Sparse depth dimensions ({sparse_depth.shape}) must match that of the image ({original_hw})"
                )

            with torch.no_grad():
                # Encode input image into latent space
                image_latent, pred_latent = self.prepare_latents(
                    image, None, generator, 1, 1
                )  # [N*E,4,h,w], [N*E,4,h,w]
            del image

            # Preprocess sparse depth
            sparse_depth = torch.from_numpy(sparse_depth)[None, None].float()
            sparse_depth = sparse_depth.to(device)
            sparse_mask = sparse_depth > 0

            # Set up optimization targets
            scale = torch.nn.Parameter(torch.ones(1, device=device), requires_grad=True)
            shift = torch.nn.Parameter(torch.ones(1, device=device), requires_grad=True)
            pred_latent = torch.nn.Parameter(pred_latent, requires_grad=True)

            sparse_values = sparse_depth[sparse_mask]
            # NOTE: if all valid samples share one value, sparse_range is 0 and
            # `scale` receives no gradient.
            sparse_range = (sparse_values.max() - sparse_values.min()).item()
            sparse_lower = (sparse_values.min()).item()

            def affine_to_metric(depth):
                # Squaring keeps the effective scale and shift non-negative.
                return (scale**2) * sparse_range * depth + (shift**2) * sparse_lower

            def latent_to_metric(latent):
                # Decode a depth latent and map it to metric depth at the
                # original image resolution.
                affine_invariant_prediction = self.decode_prediction(
                    latent
                )  # [E,1,PPH,PPW]
                prediction = affine_to_metric(affine_invariant_prediction)
                prediction = self.image_processor.unpad_image(
                    prediction, padding
                )  # [E,1,PH,PW]
                prediction = self.image_processor.resize_antialias(
                    prediction, original_resolution, "bilinear", is_aa=False
                )  # [1,1,H,W]
                return prediction

            def loss_l1l2(pred, target):
                # Returns (L1 + MSE loss, RMSE); the RMSE is only reported.
                out_l1 = torch.nn.functional.l1_loss(pred, target)
                out_l2 = torch.nn.functional.mse_loss(pred, target)
                out = out_l1 + out_l2
                return out, out_l2.sqrt()

            optimizer = torch.optim.Adam(
                [
                    {"params": [scale, shift], "lr": 0.005},
                    {"params": [pred_latent], "lr": 0.05},
                ]
            )

            # Process the denoising loop
            self.scheduler.set_timesteps(num_inference_steps, device=device)
            for t in self.progress_bar(
                self.scheduler.timesteps, desc=f"Marigold-DC steps ({str(device)})..."
            ):
                optimizer.zero_grad()

                batch_latent = torch.cat([image_latent, pred_latent], dim=1)  # [1,8,h,w]
                noise = self.unet(
                    batch_latent,
                    t,
                    encoder_hidden_states=self.empty_text_embedding,
                    return_dict=False,
                )[0]  # [1,4,h,w]

                # Compute pred_epsilon to later rescale the depth latent
                # gradient. The formula assumes the UNet output is a
                # v-prediction (TODO confirm for the loaded checkpoint).
                with torch.no_grad():
                    alpha_prod_t = self.scheduler.alphas_cumprod[t]
                    beta_prod_t = 1 - alpha_prod_t
                    pred_epsilon = (alpha_prod_t**0.5) * noise + (
                        beta_prod_t**0.5
                    ) * pred_latent

                # Runs with autograd enabled so that the loss can reach
                # pred_latent through pred_original_sample.
                step_output = self.scheduler.step(
                    noise, t, pred_latent, generator=generator
                )

                # Preview the final output depth, compute loss with guidance, backprop
                pred_original_sample = step_output.pred_original_sample
                current_metric_estimate = latent_to_metric(pred_original_sample)
                loss, rmse = loss_l1l2(
                    current_metric_estimate[sparse_mask], sparse_depth[sparse_mask]
                )
                loss.backward()

                # Rescale the latent gradient so its L2 norm equals that of
                # pred_epsilon (guarding against a zero gradient norm).
                with torch.no_grad():
                    pred_epsilon_norm = torch.linalg.norm(pred_epsilon).item()
                    depth_latent_grad_norm = torch.linalg.norm(pred_latent.grad).item()
                    scaling_factor = pred_epsilon_norm / max(depth_latent_grad_norm, 1e-8)
                    pred_latent.grad *= scaling_factor

                optimizer.step()

                # Advance the (now optimized) latent to the next timestep.
                with torch.no_grad():
                    pred_latent.data = self.scheduler.step(
                        noise, t, pred_latent, generator=generator
                    ).prev_sample

                # Detach so consumers neither receive nor keep alive the
                # autograd graph of this step.
                yield current_metric_estimate.detach(), rmse.item()

                del (
                    pred_original_sample,
                    current_metric_estimate,
                    step_output,
                    pred_epsilon,
                    noise,
                )
                torch.cuda.empty_cache()
        finally:
            # Offload all models. This also runs when the consumer stops
            # iterating early (closing a suspended generator executes this
            # `finally`) or when an exception escapes.
            self.maybe_free_model_hooks()