#!/usr/bin/env python
"""Gradio demo for Apple's DepthPro monocular depth estimation model."""

import pathlib

import gradio as gr
import matplotlib as mpl
import numpy as np
import PIL.Image
import spaces
import torch
from gradio_imageslider import ImageSlider
from transformers import DepthProForDepthEstimation, DepthProImageProcessorFast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the DepthPro checkpoint and its fast image processor from the Hugging Face Hub.
image_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf").to(device)

# Colormap used to render the inverse depth map.
cmap = mpl.colormaps.get_cmap("Spectral_r")


@spaces.GPU(duration=20)  # Request ZeroGPU hardware for up to 20 seconds per call.
@torch.inference_mode()
def run(image: PIL.Image.Image) -> tuple[tuple[PIL.Image.Image, PIL.Image.Image], str, str, str, str]:
    inputs = image_processor(images=image, return_tensors="pt").to(device)
    outputs = model(**inputs)
    # Resize the predicted depth map back to the input resolution.
    post_processed_output = image_processor.post_process_depth_estimation(
        outputs,
        target_sizes=[(image.height, image.width)],
    )
    depth_raw = post_processed_output[0]["predicted_depth"]
    depth_min = depth_raw.min().item()
    depth_max = depth_raw.max().item()

    # Render inverse depth, min-max normalized to [0, 255]; inverting gives
    # nearby regions most of the value range.
    inverse_depth = 1 / depth_raw
    normalized_inverse_depth = (inverse_depth - inverse_depth.min()) / (inverse_depth.max() - inverse_depth.min())
    normalized_inverse_depth = normalized_inverse_depth * 255.0
    normalized_inverse_depth = normalized_inverse_depth.detach().cpu().numpy()
    normalized_inverse_depth = PIL.Image.fromarray(normalized_inverse_depth.astype("uint8"))

    # Apply the colormap and drop the alpha channel.
    colored_inverse_depth = PIL.Image.fromarray(
        (cmap(np.array(normalized_inverse_depth))[:, :, :3] * 255).astype(np.uint8)
    )

    # DepthPro also estimates the camera's field of view and focal length.
    field_of_view = post_processed_output[0]["field_of_view"].item()
    focal_length = post_processed_output[0]["focal_length"].item()

    return (
        (image, colored_inverse_depth),
        f"{field_of_view:.2f}",
        f"{focal_length:.2f}",
        f"{depth_min:.2f}",
        f"{depth_max:.2f}",
    )


with gr.Blocks(css="style.css") as demo:
    gr.Markdown("# DepthPro")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            run_button = gr.Button()
        with gr.Column():
            # Before/after slider comparing the input with the colored depth map.
            output_image = ImageSlider()
            with gr.Row():
                output_field_of_view = gr.Textbox(label="Field of View")
                output_focal_length = gr.Textbox(label="Focal Length")
                output_depth_min = gr.Textbox(label="Depth Min")
                output_depth_max = gr.Textbox(label="Depth Max")

    gr.Examples(
        examples=sorted(pathlib.Path("images").glob("*.jpg")),
        inputs=input_image,
        fn=run,
        outputs=[
            output_image,
            output_field_of_view,
            output_focal_length,
            output_depth_min,
            output_depth_max,
        ],
    )

    run_button.click(
        fn=run,
        inputs=input_image,
        outputs=[
            output_image,
            output_field_of_view,
            output_focal_length,
            output_depth_min,
            output_depth_max,
        ],
    )

if __name__ == "__main__":
    demo.queue().launch()