#!/usr/bin/env python

import pathlib

import gradio as gr
import matplotlib as mpl
import numpy as np
import PIL.Image
import spaces
import torch
from gradio_imageslider import ImageSlider
from transformers import DepthProForDepthEstimation, DepthProImageProcessorFast
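
# Load Apple's DepthPro model and its fast image processor once at startup,
# moving the weights to the GPU when one is available.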
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf").to(device)
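
# "Spectral_r" maps small values to blue and large values to red, so nearby
# regions (large inverse depth) come out warm in the rendered depth map.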
cmap = mpl.colormaps.get_cmap("Spectral_r")


# A ZeroGPU Space attaches a GPU only while a function decorated with
# @spaces.GPU is running; inference_mode disables autograd bookkeeping
# for the forward pass.
@spaces.GPU
@torch.inference_mode()
def run(image: PIL.Image.Image) -> tuple[tuple[PIL.Image.Image, PIL.Image.Image], str, str, str, str]:
    inputs = image_processor(images=image, return_tensors="pt").to(device)
    outputs = model(**inputs)
    post_processed_output = image_processor.post_process_depth_estimation(
        outputs,
        target_sizes=[(image.height, image.width)],
    )
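
    # The post-processed output holds metric depth (in meters) resized to the
    # original image resolution, along with the estimated camera intrinsics.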
    depth_raw = post_processed_output[0]["predicted_depth"]
    depth_min = depth_raw.min().item()
    depth_max = depth_raw.max().item()
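
    # Visualize inverse depth so nearby structure gets most of the dynamic
    # range, then rescale to 0-255 for an 8-bit grayscale image.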
    inverse_depth = 1 / depth_raw
    normalized_inverse_depth = (inverse_depth - inverse_depth.min()) / (inverse_depth.max() - inverse_depth.min())
    normalized_inverse_depth = normalized_inverse_depth * 255.0
    normalized_inverse_depth = normalized_inverse_depth.detach().cpu().numpy()
    normalized_inverse_depth = PIL.Image.fromarray(normalized_inverse_depth.astype("uint8"))
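
    # Map the grayscale values through the colormap and drop the alpha channel.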
    colored_inverse_depth = PIL.Image.fromarray(
        (cmap(np.array(normalized_inverse_depth))[:, :, :3] * 255).astype(np.uint8)
    )
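
    # DepthPro also predicts the horizontal field of view (degrees) and the
    # focal length (pixels) of the camera that took the photo.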
    field_of_view = post_processed_output[0]["field_of_view"].item()
    focal_length = post_processed_output[0]["focal_length"].item()

    return (
        (image, colored_inverse_depth),
        f"{field_of_view:.2f}",
        f"{focal_length:.2f}",
        f"{depth_min:.2f}",
        f"{depth_max:.2f}",
    )


with gr.Blocks(css="style.css") as demo:
    gr.Markdown("# DepthPro")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            run_button = gr.Button()
        with gr.Column():
            output_image = ImageSlider()
            with gr.Row():
                output_field_of_view = gr.Textbox(label="Field of View")
                output_focal_length = gr.Textbox(label="Focal Length")
                output_depth_min = gr.Textbox(label="Depth Min")
                output_depth_max = gr.Textbox(label="Depth Max")
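
    # Example images bundled with the Space (in the images/ directory);
    # clicking one loads it into the input so it can be run through the model.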
    gr.Examples(
        examples=sorted(pathlib.Path("images").glob("*.jpg")),
        inputs=input_image,
        fn=run,
        outputs=[
            output_image,
            output_field_of_view,
            output_focal_length,
            output_depth_min,
            output_depth_max,
        ],
    )
    run_button.click(
        fn=run,
        inputs=input_image,
        outputs=[
            output_image,
            output_field_of_view,
            output_focal_length,
            output_depth_min,
            output_depth_max,
        ],
    )

if __name__ == "__main__":
    demo.queue().launch()
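
# Local usage sketch (assumption: these packages cover the imports above;
# the Space's own requirements.txt may pin different versions):
#   pip install torch transformers gradio gradio_imageslider spaces matplotlib
#   python app.py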