# depth/app.py
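"""Gradio demo: monocular depth estimation on single images or frame sequences,
using Intel's DPT-Large model from the Hugging Face hub."""
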
import gradio as gr
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
import torch
import numpy as np
from PIL import Image
import os

# Load the pretrained DPT-Large model and its feature extractor from the Hugging Face hub.
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
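# Note: inference runs on the CPU here; if a GPU is available, the model could be
# moved with model.to("cuda") and the encoded inputs sent to the same device.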


def get_image_depth(image):
    # prepare the image for the model
    encoding = feature_extractor(image, return_tensors="pt")

    # forward pass
    with torch.no_grad():
        outputs = model(**encoding)
        predicted_depth = outputs.predicted_depth

    # interpolate the prediction back to the original image resolution
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    ).squeeze()

    # scale to 0-255 and convert to a grayscale PIL image
    output = prediction.cpu().numpy()
    formatted = (output * 255 / np.max(output)).astype("uint8")
    img = Image.fromarray(formatted)
    return img


def process_sequence(files):
    # run depth estimation on every uploaded frame and save the results to ./output
    os.makedirs("output", exist_ok=True)  # make sure the output folder exists
    file_paths = [file.name for file in files]
    for file_path in file_paths:
        image = Image.open(file_path).convert("RGB")
        depth_image = get_image_depth(image)
        depth_image.save(os.path.join("output", os.path.basename(file_path)))
    # gr.Info shows a toast notification when called inside an event handler (Gradio >= 3.44)
    gr.Info("Check the output folder for the depth frames.")
    return file_paths


title = "# Depth estimation demo"
description = "Demo for Intel's DPT (Dense Prediction Transformer) depth estimation model"

with gr.Blocks() as iface:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            with gr.Tab(label="Single image"):
                image = gr.Image(type="pil")
                button = gr.Button(value="Get depth", interactive=True, variant="primary")
                image_output = gr.Image(type="pil", label="Predicted depth")
        with gr.Column():
            with gr.Tab(label="Frames"):
                file_output = gr.File(visible=False)
                upload_button = gr.UploadButton("Select directory", file_types=["image"], file_count="directory")
                upload_button.upload(process_sequence, upload_button, file_output)
                # output = gr.Video(label="Predicted Depth")
                message = gr.Text(value="Check output folder for the depth frames.")

    button.click(
        fn=get_image_depth,
        inputs=[image],
        outputs=[image_output],
    )
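# Process one request at a time and launch the app. Run locally with `python app.py`;
# Gradio serves the UI on http://127.0.0.1:7860 by default.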
iface.queue(concurrency_count=1)
iface.launch(debug=True, enable_queue=True)