sashasax commited on
Commit
da917e1
·
1 Parent(s): 46b342d

update readme

Browse files
Files changed (3) hide show
  1. README.md +2 -2
  2. app.py +21 -7
  3. requirements.txt +2 -1
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Omnidata Monocular Surface Normal Dpt Hybrid 384
3
  emoji: 🐠
4
  colorFrom: green
5
  colorTo: purple
@@ -9,5 +9,5 @@ app_file: app.py
9
  pinned: false
10
  license: cc-by-nc-4.0
11
  ---
12
-
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Omnidata Monocular Depth DPT Hybrid 384
3
  emoji: 🐠
4
  colorFrom: green
5
  colorTo: purple
 
9
  pinned: false
10
  license: cc-by-nc-4.0
11
  ---
12
+ # [Use these models in your code:](https://github.com/alexsax/omnidata_models/tree/main)
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -4,13 +4,14 @@ import torch.nn.functional as F
4
  from torchvision import transforms
5
  import PIL
6
  from PIL import Image
7
- import os
 
8
  from typing import Tuple
9
 
10
 
11
  def setup_model(device: torch.device) -> Tuple[torch.nn.Module, int]:
12
  image_size = 384
13
- model = torch.hub.load('alexsax/omnidata_models', 'surface_normal_dpt_hybrid_384')
14
  model.to(device)
15
  model.eval()
16
 
@@ -21,13 +22,14 @@ def setup_transforms(image_size: int) -> transforms.Compose:
21
  transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR),
22
  transforms.CenterCrop(image_size),
23
  transforms.ToTensor(),
 
24
  ])
25
 
26
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27
  model, image_size = setup_model(device)
28
  trans_totensor = setup_transforms(image_size)
29
 
30
- def estimate_surface_normal(input_image: PIL.Image.Image) -> PIL.Image.Image:
31
  with torch.no_grad():
32
  img_tensor = trans_totensor(input_image)[:3].unsqueeze(0).to(device)
33
 
@@ -35,16 +37,28 @@ def estimate_surface_normal(input_image: PIL.Image.Image) -> PIL.Image.Image:
35
  img_tensor = img_tensor.repeat_interleave(3, 1)
36
 
37
  output = model(img_tensor).clamp(min=0, max=1)
38
- output_image = transforms.ToPILImage()(output[0])
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  return output_image
41
 
42
  iface = gr.Interface(
43
- fn=estimate_surface_normal,
44
  inputs=gr.Image(type="pil"),
45
  outputs=gr.Image(type="pil"),
46
- title="Monocular Surface Normal Estimation: Omnidata DPT-Hybrid",
47
- description="Upload an image to estimate monocular surface normals.",
48
  examples=[
49
  "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/test1_rgb.png?raw=true",
50
  "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test2.png?raw=true",
 
4
  from torchvision import transforms
5
  import PIL
6
  from PIL import Image
7
+ import matplotlib.pyplot as plt
8
+ import io
9
  from typing import Tuple
10
 
11
 
12
  def setup_model(device: torch.device) -> Tuple[torch.nn.Module, int]:
13
  image_size = 384
14
+ model = torch.hub.load('alexsax/omnidata_models', 'depth_dpt_hybrid_384')
15
  model.to(device)
16
  model.eval()
17
 
 
22
  transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR),
23
  transforms.CenterCrop(image_size),
24
  transforms.ToTensor(),
25
+ transforms.Normalize(mean=0.5, std=0.5)
26
  ])
27
 
28
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29
  model, image_size = setup_model(device)
30
  trans_totensor = setup_transforms(image_size)
31
 
32
+ def estimate_depth(input_image: PIL.Image.Image) -> PIL.Image.Image:
33
  with torch.no_grad():
34
  img_tensor = trans_totensor(input_image)[:3].unsqueeze(0).to(device)
35
 
 
37
  img_tensor = img_tensor.repeat_interleave(3, 1)
38
 
39
  output = model(img_tensor).clamp(min=0, max=1)
40
+ output = F.interpolate(output.unsqueeze(0), (512, 512), mode='bicubic').squeeze(0)
41
+ output = 1 - output.clamp(0, 1)
42
+
43
+ # Convert to colormap
44
+ plt.figure(figsize=(10, 10))
45
+ plt.imshow(output[0].cpu().numpy(), cmap='viridis')
46
+ plt.axis('off')
47
+
48
+ buf = io.BytesIO()
49
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
50
+ buf.seek(0)
51
+ output_image = Image.open(buf)
52
+ plt.close()
53
 
54
  return output_image
55
 
56
  iface = gr.Interface(
57
+ fn=estimate_depth,
58
  inputs=gr.Image(type="pil"),
59
  outputs=gr.Image(type="pil"),
60
+ title="Monocular Depth Estimation: Omnidata DPT-Hybrid",
61
+ description="Upload an image to estimate monocular depth.",
62
  examples=[
63
  "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/test1_rgb.png?raw=true",
64
  "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test2.png?raw=true",
requirements.txt CHANGED
@@ -2,4 +2,5 @@ torch>=1.9.0
2
  torchvision>=0.10.0
3
  timm==0.4.12
4
  pillow
5
- requests
 
 
2
  torchvision>=0.10.0
3
  timm==0.4.12
4
  pillow
5
+ requests
6
+ matplotlib