CharlieAmalet commited on
Commit
d2eaa46
·
verified ·
1 Parent(s): 7e8f68e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  torch.jit.script = lambda f: f
3
  from zoedepth.utils.misc import colorize, save_raw_16bit
4
  from zoedepth.utils.geometry import depth_to_points, create_triangles
 
5
  import gradio as gr
6
  import spaces
7
 
@@ -29,6 +30,9 @@ css = """
29
  DEVICE = 'cuda'
30
  model = torch.hub.load('isl-org/ZoeDepth', "ZoeD_N", pretrained=True).to("cpu").eval()
31
 
 
 
 
32
  # ----------- Depth functions
33
  @spaces.GPU(enable_queue=True)
34
  def save_raw_16bit(depth, fpath="raw.png"):
@@ -46,10 +50,18 @@ def process_image(image: Image.Image):
46
  global model
47
  image = image.convert("RGB")
48
 
49
- model.to(DEVICE)
50
- out = model.infer_pil(image)
 
 
 
 
 
 
 
 
 
51
 
52
- processed_array = save_raw_16bit(colorize(out)[:, :, 0])
53
  return Image.fromarray(processed_array)
54
 
55
  # ----------- Depth functions
 
2
  torch.jit.script = lambda f: f
3
  from zoedepth.utils.misc import colorize, save_raw_16bit
4
  from zoedepth.utils.geometry import depth_to_points, create_triangles
5
+ from diffusers import DiffusionPipeline
6
  import gradio as gr
7
  import spaces
8
 
 
30
  DEVICE = 'cuda'
31
  model = torch.hub.load('isl-org/ZoeDepth', "ZoeD_N", pretrained=True).to("cpu").eval()
32
 
33
+ CHECKPOINT = "prs-eth/marigold-v1-0"
34
+ pipe = DiffusionPipeline.from_pretrained(CHECKPOINT)
35
+
36
  # ----------- Depth functions
37
  @spaces.GPU(enable_queue=True)
38
  def save_raw_16bit(depth, fpath="raw.png"):
 
50
  global model
51
  image = image.convert("RGB")
52
 
53
+ device = "cuda" if torch.cuda.is_available() else "cpu"
54
+ # model.to(device)
55
+ # depth = model.infer_pil(image)
56
+
57
+ # processed_array = save_raw_16bit(colorize(depth)[:, :, 0])
58
+ # return Image.fromarray(processed_array)
59
+
60
+ model.to(device)
61
+
62
+ # # inference
63
+ processed_array = pipe(image)["depth"]
64
 
 
65
  return Image.fromarray(processed_array)
66
 
67
  # ----------- Depth functions