import logging
import warnings

import diffusers
import numpy as np
import torch
from diffusers import MarigoldDepthPipeline

warnings.simplefilter(action="ignore", category=FutureWarning)
diffusers.utils.logging.disable_progress_bar()
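

# Marigold-DC: depth completion built on top of the Marigold monocular depth
# pipeline. During inference, the denoising trajectory is guided so that the
# decoded dense depth agrees with a sparse set of known depth measurements.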
class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
    def __call__(
        self,
        image,
        sparse_depth,
        num_inference_steps=50,
        processing_resolution=0,
        seed=2024,
        dry_run=False,
    ):
        # Resolve variables
        device = self._execution_device
        generator = torch.Generator(device=device).manual_seed(seed)

        if dry_run:
            # Yield placeholder predictions so callers can be exercised
            # without running the model
            logging.warning("Dry run mode")
            for i in range(num_inference_steps):
                yield np.array(image)[:, :, 0].astype(float), float(np.log(i + 1))
            return
        # Check inputs
        if num_inference_steps is None:
            raise ValueError("Invalid num_inference_steps")
        if not isinstance(sparse_depth, np.ndarray) or sparse_depth.ndim != 2:
            raise ValueError(
                "Sparse depth should be a 2D numpy ndarray with zeros at missing positions"
            )
        with torch.no_grad():
            # Prepare and cache the empty text conditioning used by the UNet
            if self.empty_text_embedding is None:
                prompt = ""
                text_inputs = self.tokenizer(
                    prompt,
                    padding="do_not_pad",
                    max_length=self.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                text_input_ids = text_inputs.input_ids.to(device)
                self.empty_text_embedding = self.text_encoder(text_input_ids)[0]  # [1,2,1024]
        # Preprocess the input image
        image, padding, original_resolution = self.image_processor.preprocess(
            image,
            processing_resolution=processing_resolution,
            device=device,
            dtype=self.dtype,
        )  # [N,3,PPH,PPW]

        if sparse_depth.shape != original_resolution:
            raise ValueError(
                f"Sparse depth dimensions ({sparse_depth.shape}) must match those of "
                f"the input image ({tuple(original_resolution)})"
            )
        with torch.no_grad():
            # Encode the input image into latent space
            image_latent, pred_latent = self.prepare_latents(
                image, None, generator, 1, 1
            )  # [N*E,4,h,w], [N*E,4,h,w]
        del image

        # Preprocess sparse depth: positive entries are observations, zeros are holes
        sparse_depth = torch.from_numpy(sparse_depth)[None, None].float()
        sparse_depth = sparse_depth.to(device)
        sparse_mask = sparse_depth > 0
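
        # Test-time guidance: optimize a global affine transform (scale, shift)
        # together with the depth latent so that the decoded prediction matches
        # the sparse observations at every denoising step.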
        # Set up optimization targets
        scale = torch.nn.Parameter(torch.ones(1, device=device), requires_grad=True)
        shift = torch.nn.Parameter(torch.ones(1, device=device), requires_grad=True)
        pred_latent = torch.nn.Parameter(pred_latent, requires_grad=True)

        sparse_range = (
            sparse_depth[sparse_mask].max() - sparse_depth[sparse_mask].min()
        ).item()
        sparse_lower = (sparse_depth[sparse_mask].min()).item()
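
        # Squaring scale and shift keeps both non-negative, so the affine map
        # below can only stretch and lift the affine-invariant prediction into
        # the observed sparse depth range.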
        def affine_to_metric(depth):
            return (scale**2) * sparse_range * depth + (shift**2) * sparse_lower

        def latent_to_metric(latent):
            affine_invariant_prediction = self.decode_prediction(latent)  # [E,1,PPH,PPW]
            prediction = affine_to_metric(affine_invariant_prediction)
            prediction = self.image_processor.unpad_image(prediction, padding)  # [E,1,PH,PW]
            prediction = self.image_processor.resize_antialias(
                prediction, original_resolution, "bilinear", is_aa=False
            )  # [1,1,H,W]
            return prediction
        def loss_l1l2(input, target):
            # Combined L1 + L2 guidance loss; also return RMSE for logging
            out_l1 = torch.nn.functional.l1_loss(input, target)
            out_l2 = torch.nn.functional.mse_loss(input, target)
            return out_l1 + out_l2, out_l2.sqrt()
        optimizer = torch.optim.Adam(
            [
                {"params": [scale, shift], "lr": 0.005},
                {"params": [pred_latent], "lr": 0.05},
            ]
        )
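
        # Guided denoising loop: at each timestep, predict with the UNet,
        # preview the clean depth via pred_original_sample, score it against
        # the sparse points, backpropagate into (scale, shift, pred_latent),
        # and only then take the scheduler step from the updated latent.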
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        for t in self.progress_bar(
            self.scheduler.timesteps, desc=f"Marigold-DC steps ({str(device)})..."
        ):
            optimizer.zero_grad()

            # Concatenate image and depth latents along channels, as in Marigold
            batch_latent = torch.cat([image_latent, pred_latent], dim=1)  # [1,8,h,w]
            noise = self.unet(
                batch_latent,
                t,
                encoder_hidden_states=self.empty_text_embedding,
                return_dict=False,
            )[0]  # [1,4,h,w]
            # Compute pred_epsilon to later rescale the depth latent gradient
            with torch.no_grad():
                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                beta_prod_t = 1 - alpha_prod_t
                # Convert the model output to epsilon (v-prediction parameterization)
                pred_epsilon = (alpha_prod_t**0.5) * noise + (beta_prod_t**0.5) * pred_latent

            step_output = self.scheduler.step(noise, t, pred_latent, generator=generator)
            # Preview the final output depth, compute the guidance loss, backprop
            pred_original_sample = step_output.pred_original_sample
            current_metric_estimate = latent_to_metric(pred_original_sample)
            loss, rmse = loss_l1l2(
                current_metric_estimate[sparse_mask], sparse_depth[sparse_mask]
            )
            loss.backward()
            # Rescale the latent gradient to match the norm of pred_epsilon so
            # the guidance update has the same magnitude as a denoising update
            with torch.no_grad():
                pred_epsilon_norm = torch.linalg.norm(pred_epsilon).item()
                depth_latent_grad_norm = torch.linalg.norm(pred_latent.grad).item()
                scaling_factor = pred_epsilon_norm / max(depth_latent_grad_norm, 1e-8)
                pred_latent.grad *= scaling_factor

            optimizer.step()

            # Re-run the scheduler step from the freshly optimized latent to
            # obtain the sample for the next timestep
            with torch.no_grad():
                pred_latent.data = self.scheduler.step(
                    noise, t, pred_latent, generator=generator
                ).prev_sample

            yield current_metric_estimate, rmse.item()
            del (
                pred_original_sample,
                current_metric_estimate,
                step_output,
                pred_epsilon,
                noise,
            )
            torch.cuda.empty_cache()

        # Offload all models
        self.maybe_free_model_hooks()
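

# Minimal usage sketch (illustrative, not part of the pipeline above).
# Assumptions: "prs-eth/marigold-depth-v1-0" as the Marigold depth checkpoint,
# a local "example.png" input, and randomly sampled sparse points standing in
# for real sensor measurements.
if __name__ == "__main__":
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = MarigoldDepthCompletionPipeline.from_pretrained(
        "prs-eth/marigold-depth-v1-0"
    ).to(device)

    img = Image.open("example.png").convert("RGB")
    h, w = img.height, img.width

    # Sparse metric depth map: zeros mark missing positions
    sparse = np.zeros((h, w), dtype=np.float32)
    rng = np.random.default_rng(0)
    ys = rng.integers(0, h, size=100)
    xs = rng.integers(0, w, size=100)
    sparse[ys, xs] = rng.uniform(1.0, 10.0, size=100).astype(np.float32)

    depth = None
    for depth, rmse in pipe(img, sparse, num_inference_steps=50):
        print(f"RMSE against sparse points: {rmse:.4f}")
    # After the loop, `depth` holds the final completed metric depth map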