update readme
- README.md: +2 -2
- app.py: +21 -7
- requirements.txt: +2 -1
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Omnidata Monocular
+title: Omnidata Monocular Depth DPT Hybrid 384
 emoji: 🐠
 colorFrom: green
 colorTo: purple
@@ -9,5 +9,5 @@ app_file: app.py
 pinned: false
 license: cc-by-nc-4.0
 ---
-
+# [Use these models in your code:](https://github.com/alexsax/omnidata_models/tree/main)
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
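For reference, the heading added above links to the omnidata_models repo. A minimal sketch (not part of this commit) of loading the same checkpoint outside the Space, using the torch.hub entrypoint that app.py now calls:

```python
# Sketch: pull the depth checkpoint served by this Space via torch.hub.
# Repo and entrypoint names are taken from the app.py change in this commit;
# the weights download on first use.
import torch

model = torch.hub.load('alexsax/omnidata_models', 'depth_dpt_hybrid_384')
model.eval()
```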
app.py CHANGED
@@ -4,13 +4,14 @@ import torch.nn.functional as F
 from torchvision import transforms
 import PIL
 from PIL import Image
-import
+import matplotlib.pyplot as plt
+import io
 from typing import Tuple
 
 
 def setup_model(device: torch.device) -> Tuple[torch.nn.Module, int]:
     image_size = 384
-    model = torch.hub.load('alexsax/omnidata_models', '
+    model = torch.hub.load('alexsax/omnidata_models', 'depth_dpt_hybrid_384')
     model.to(device)
     model.eval()
 
@@ -21,13 +22,14 @@ def setup_transforms(image_size: int) -> transforms.Compose:
         transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR),
         transforms.CenterCrop(image_size),
         transforms.ToTensor(),
+        transforms.Normalize(mean=0.5, std=0.5)
     ])
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model, image_size = setup_model(device)
 trans_totensor = setup_transforms(image_size)
 
-def estimate_surface_normal(input_image: PIL.Image.Image) -> PIL.Image.Image:
+def estimate_depth(input_image: PIL.Image.Image) -> PIL.Image.Image:
     with torch.no_grad():
         img_tensor = trans_totensor(input_image)[:3].unsqueeze(0).to(device)
 
@@ -35,16 +37,28 @@ def estimate_surface_normal(input_image: PIL.Image.Image) -> PIL.Image.Image:
         img_tensor = img_tensor.repeat_interleave(3, 1)
 
         output = model(img_tensor).clamp(min=0, max=1)
-
+        output = F.interpolate(output.unsqueeze(0), (512, 512), mode='bicubic').squeeze(0)
+        output = 1 - output.clamp(0, 1)
+
+        # Convert to colormap
+        plt.figure(figsize=(10, 10))
+        plt.imshow(output[0].cpu().numpy(), cmap='viridis')
+        plt.axis('off')
+
+        buf = io.BytesIO()
+        plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
+        buf.seek(0)
+        output_image = Image.open(buf)
+        plt.close()
 
     return output_image
 
 iface = gr.Interface(
-    fn=estimate_surface_normal,
+    fn=estimate_depth,
     inputs=gr.Image(type="pil"),
     outputs=gr.Image(type="pil"),
-    title="Monocular
-    description="Upload an image to estimate monocular
+    title="Monocular Depth Estimation: Omnidata DPT-Hybrid",
+    description="Upload an image to estimate monocular depth.",
     examples=[
         "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/test1_rgb.png?raw=true",
         "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test2.png?raw=true",
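A note on the colormap step added above: rendering through a Matplotlib figure and re-reading the saved PNG works, but bbox_inches='tight' can alter the output dimensions and a figure is allocated per request. A lighter alternative, offered here only as a sketch (not part of this commit), applies the colormap directly; `depth` stands in for `output[0].cpu().numpy()`, already scaled to [0, 1]:

```python
# Sketch: colorize a [0, 1] depth map without the figure/savefig round-trip.
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def colorize(depth: np.ndarray) -> Image.Image:
    rgba = plt.get_cmap('viridis')(depth)         # (H, W, 4) floats in [0, 1]
    rgb = (rgba[..., :3] * 255).astype(np.uint8)  # drop alpha, convert to 8-bit
    return Image.fromarray(rgb)
```

This keeps the result at exactly the 512x512 resolution produced by F.interpolate.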
requirements.txt CHANGED
@@ -2,4 +2,5 @@ torch>=1.9.0
 torchvision>=0.10.0
 timm==0.4.12
 pillow
-requests
+requests
+matplotlib
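matplotlib is the only new dependency, pulled in by the colormap rendering added to app.py. A quick smoke test of the full stack, sketched under two assumptions: that app.py is importable as a module and that it does not launch the Gradio interface at import time:

```python
# Sketch: fetch one of the example images listed in app.py and run the
# new depth pipeline on it. `from app import estimate_depth` assumes the
# Space's app.py is on the import path and has no launch-on-import side effect.
import io

import requests
from PIL import Image

from app import estimate_depth

url = ("https://github.com/EPFL-VILAB/omnidata/blob/main/"
       "omnidata_tools/torch/assets/test1_rgb.png?raw=true")
img = Image.open(io.BytesIO(requests.get(url).content)).convert("RGB")
estimate_depth(img).save("depth_test1.png")
```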