import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
import PIL
from PIL import Image
import matplotlib
matplotlib.use('Agg')  # non-interactive backend: figures are rendered off-screen in the server process
import matplotlib.pyplot as plt
import io
from typing import Tuple


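# Download the pretrained Omnidata DPT-Hybrid depth model via torch.hub and switch it to eval mode.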
def setup_model(device: torch.device) -> Tuple[torch.nn.Module, int]:
    image_size = 384
    model = torch.hub.load('alexsax/omnidata_models', 'depth_dpt_hybrid_384')
    model.to(device)
    model.eval()

    return model, image_size


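# Preprocessing: resize the short side to image_size, center-crop to a square, and normalize to [-1, 1].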
def setup_transforms(image_size: int) -> transforms.Compose:
    return transforms.Compose([
        transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=0.5, std=0.5),
    ])


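# Initialize the model and transforms once at import time; every Gradio request reuses them.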
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, image_size = setup_model(device)
trans_totensor = setup_transforms(image_size)


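# Run depth inference on one image and render the prediction as a colormapped PIL image.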
def estimate_depth(input_image: PIL.Image.Image) -> PIL.Image.Image:
    with torch.no_grad():
        # Keep at most three channels (drops any alpha), add a batch dimension, move to the model's device.
        img_tensor = trans_totensor(input_image)[:3].unsqueeze(0).to(device)

        # Replicate grayscale inputs across three channels.
        if img_tensor.shape[1] == 1:
            img_tensor = img_tensor.repeat_interleave(3, 1)

        output = model(img_tensor).clamp(min=0, max=1)
        # Upsample the prediction to 512x512 for display; bicubic resampling can
        # overshoot [0, 1], so clamp again before inverting.
        output = F.interpolate(output.unsqueeze(0), (512, 512), mode='bicubic').squeeze(0)
        # Invert so that nearer surfaces appear brighter under the colormap.
        output = 1 - output.clamp(0, 1)

        # Render the depth map as a colormapped matplotlib figure.
        plt.figure(figsize=(10, 10))
        plt.imshow(output[0].cpu().numpy(), cmap='viridis')
        plt.axis('off')

        # Save the figure into an in-memory PNG buffer and reload it as a PIL image.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        buf.seek(0)
        output_image = Image.open(buf)
        output_image.load()  # force the lazy PIL loader to read the buffer before it goes out of scope
        plt.close()

        return output_image


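# Gradio UI: a single image-in, image-out interface with hosted example images.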
iface = gr.Interface(
    fn=estimate_depth,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Monocular Depth Estimation: Omnidata DPT-Hybrid",
    description="Upload an image to estimate monocular depth. To run these models locally, load them with `torch.hub.load`. Code and examples are in our [GitHub](https://github.com/alexsax/omnidata_models) repository; more information and the paper are on the project page [Omnidata: A Scalable Pipeline for Making Multi-Task Mid-Level Vision Datasets from 3D Scans](https://omnidata.epfl.ch/).",
    examples=[
"https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/test1_rgb.png?raw=true", |
|
"https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test2.png?raw=true", |
|
"https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test3.png?raw=true", |
|
"https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test4.png?raw=true", |
|
"https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test5.png?raw=true", |
|
"https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test6.png?raw=true", |
|
"https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test7.png?raw=true", |
|
"https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test8.png?raw=true", |
|
"https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test9.png?raw=true", |
|
"https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test10.png?raw=true", |
|
], |
|
) |


if __name__ == "__main__":
    iface.launch()