CharlieAmalet commited on
Commit
fe381bb
·
verified ·
1 Parent(s): 6af6ea2

Update depth.py

Browse files
Files changed (1) hide show
  1. depth.py +2 -2
depth.py CHANGED
@@ -24,9 +24,9 @@ def process_image(model, image: Image.Image):
24
  processed_array = save_raw_16bit(colorize(out)[:, :, 0])
25
  return Image.fromarray(processed_array)
26
 
27
- def depth_interface(model):
28
  with gr.Row():
29
  inputs=gr.Image(label="Input Image", type='pil') # Input is an image
30
  outputs=gr.Image(label="Depth Map", type='pil') # Output is also an image
31
  generate_btn = gr.Button(value="Generate")
32
- generate_btn.click(partial(process_image, model), inputs=inputs, outputs=outputs, api_name="generate_depth")
 
24
  processed_array = save_raw_16bit(colorize(out)[:, :, 0])
25
  return Image.fromarray(processed_array)
26
 
27
+ def depth_interface(model, device):
28
  with gr.Row():
29
  inputs=gr.Image(label="Input Image", type='pil') # Input is an image
30
  outputs=gr.Image(label="Depth Map", type='pil') # Output is also an image
31
  generate_btn = gr.Button(value="Generate")
32
+ generate_btn.click(partial(process_image, model.to(device)), inputs=inputs, outputs=outputs, api_name="generate_depth")