Spaces:
Runtime error
Runtime error
pablovela5620
commited on
Commit
•
8c45713
1
Parent(s):
c2a846f
Refactor predict_normal function and add DSINE demo
Browse files
main.py
CHANGED
@@ -20,7 +20,7 @@ model = utils.load_checkpoint("./checkpoints/dsine.pt", model)
|
|
20 |
model.eval()
|
21 |
|
22 |
|
23 |
-
def predict_normal(img_np: np.ndarray):
|
24 |
# normalize
|
25 |
normalize = transforms.Normalize(
|
26 |
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
@@ -28,7 +28,7 @@ def predict_normal(img_np: np.ndarray):
|
|
28 |
|
29 |
with torch.no_grad():
|
30 |
img = np.array(img_np).astype(np.float32) / 255.0
|
31 |
-
img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(
|
32 |
_, _, orig_H, orig_W = img.shape
|
33 |
|
34 |
# zero-pad the input image so that both the width and height are multiples of 32
|
@@ -39,7 +39,7 @@ def predict_normal(img_np: np.ndarray):
|
|
39 |
# NOTE: if intrins is not given, we just assume that the principal point is at the center
|
40 |
# and that the field-of-view is 60 degrees (feel free to modify this assumption)
|
41 |
intrins = utils.get_intrins_from_fov(
|
42 |
-
new_fov=60.0, H=orig_H, W=orig_W, device=
|
43 |
).unsqueeze(0)
|
44 |
|
45 |
intrins[:, 0, 2] += l
|
@@ -48,7 +48,6 @@ def predict_normal(img_np: np.ndarray):
|
|
48 |
pred_norm = model(img, intrins=intrins)[-1]
|
49 |
pred_norm = pred_norm[:, :, t : t + orig_H, l : l + orig_W]
|
50 |
|
51 |
-
# save to output folder
|
52 |
# NOTE: by saving the prediction as uint8 png format, you lose a lot of precision
|
53 |
# if you want to use the predicted normals for downstream tasks, we recommend saving them as float32 NPY files
|
54 |
pred_norm_np = (
|
@@ -60,6 +59,12 @@ def predict_normal(img_np: np.ndarray):
|
|
60 |
|
61 |
|
62 |
with gr.Blocks() as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
with gr.Group():
|
64 |
with gr.Row():
|
65 |
input_img = gr.Image(label="Input image", image_mode="RGB")
|
@@ -71,9 +76,12 @@ with gr.Blocks() as demo:
|
|
71 |
|
72 |
with Modal(visible=True, allow_user_close=False) as modal:
|
73 |
gr.Markdown(
|
74 |
-
"
|
|
|
|
|
|
|
75 |
)
|
76 |
-
btn = gr.Button("I
|
77 |
btn.click(lambda: Modal(visible=False), None, modal)
|
78 |
|
79 |
if __name__ == "__main__":
|
|
|
20 |
model.eval()
|
21 |
|
22 |
|
23 |
+
def predict_normal(img_np: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
24 |
# normalize
|
25 |
normalize = transforms.Normalize(
|
26 |
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
|
|
28 |
|
29 |
with torch.no_grad():
|
30 |
img = np.array(img_np).astype(np.float32) / 255.0
|
31 |
+
img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)
|
32 |
_, _, orig_H, orig_W = img.shape
|
33 |
|
34 |
# zero-pad the input image so that both the width and height are multiples of 32
|
|
|
39 |
# NOTE: if intrins is not given, we just assume that the principal point is at the center
|
40 |
# and that the field-of-view is 60 degrees (feel free to modify this assumption)
|
41 |
intrins = utils.get_intrins_from_fov(
|
42 |
+
new_fov=60.0, H=orig_H, W=orig_W, device=device
|
43 |
).unsqueeze(0)
|
44 |
|
45 |
intrins[:, 0, 2] += l
|
|
|
48 |
pred_norm = model(img, intrins=intrins)[-1]
|
49 |
pred_norm = pred_norm[:, :, t : t + orig_H, l : l + orig_W]
|
50 |
|
|
|
51 |
# NOTE: by saving the prediction as uint8 png format, you lose a lot of precision
|
52 |
# if you want to use the predicted normals for downstream tasks, we recommend saving them as float32 NPY files
|
53 |
pred_norm_np = (
|
|
|
59 |
|
60 |
|
61 |
with gr.Blocks() as demo:
|
62 |
+
gr.Markdown(
|
63 |
+
"""
|
64 |
+
# DSINE
|
65 |
+
Unofficial Gradio demo of [DSINE: Rethinking Inductive Biases for Surface Normal Estimation](https://github.com/baegwangbin/DSINE)
|
66 |
+
"""
|
67 |
+
)
|
68 |
with gr.Group():
|
69 |
with gr.Row():
|
70 |
input_img = gr.Image(label="Input image", image_mode="RGB")
|
|
|
76 |
|
77 |
with Modal(visible=True, allow_user_close=False) as modal:
|
78 |
gr.Markdown(
|
79 |
+
"""
|
80 |
+
To use this space, you must agree to the terms and conditions.
|
81 |
+
Found [HERE](https://github.com/baegwangbin/DSINE/blob/main/LICENSE).
|
82 |
+
""",
|
83 |
)
|
84 |
+
btn = gr.Button("I Agree to the Terms and Conditions")
|
85 |
btn.click(lambda: Modal(visible=False), None, modal)
|
86 |
|
87 |
if __name__ == "__main__":
|