Spaces:
Runtime error
Runtime error
pablovela5620
commited on
Commit
•
169b74e
1
Parent(s):
8c45713
Delete test_samples function and its dependencies
Browse files
test.py
DELETED
@@ -1,86 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import sys
|
3 |
-
import glob
|
4 |
-
import argparse
|
5 |
-
import numpy as np
|
6 |
-
|
7 |
-
import torch
|
8 |
-
import torch.nn.functional as F
|
9 |
-
from torchvision import transforms
|
10 |
-
from PIL import Image
|
11 |
-
import utils.utils as utils
|
12 |
-
|
13 |
-
|
14 |
-
def test_samples(args, model, intrins=None, device="cpu"):
|
15 |
-
img_paths = glob.glob("./samples/img/*.png") + glob.glob("./samples/img/*.jpg")
|
16 |
-
img_paths.sort()
|
17 |
-
|
18 |
-
# normalize
|
19 |
-
normalize = transforms.Normalize(
|
20 |
-
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
21 |
-
)
|
22 |
-
|
23 |
-
with torch.no_grad():
|
24 |
-
for img_path in img_paths:
|
25 |
-
print(img_path)
|
26 |
-
ext = os.path.splitext(img_path)[1]
|
27 |
-
img = Image.open(img_path).convert("RGB")
|
28 |
-
img = np.array(img).astype(np.float32) / 255.0
|
29 |
-
img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)
|
30 |
-
_, _, orig_H, orig_W = img.shape
|
31 |
-
|
32 |
-
# zero-pad the input image so that both the width and height are multiples of 32
|
33 |
-
l, r, t, b = utils.pad_input(orig_H, orig_W)
|
34 |
-
img = F.pad(img, (l, r, t, b), mode="constant", value=0.0)
|
35 |
-
img = normalize(img)
|
36 |
-
|
37 |
-
intrins_path = img_path.replace(ext, ".txt")
|
38 |
-
if os.path.exists(intrins_path):
|
39 |
-
# NOTE: camera intrinsics should be given as a txt file
|
40 |
-
# it should contain the values of fx, fy, cx, cy
|
41 |
-
intrins = utils.get_intrins_from_txt(
|
42 |
-
intrins_path, device=device
|
43 |
-
).unsqueeze(0)
|
44 |
-
else:
|
45 |
-
# NOTE: if intrins is not given, we just assume that the principal point is at the center
|
46 |
-
# and that the field-of-view is 60 degrees (feel free to modify this assumption)
|
47 |
-
intrins = utils.get_intrins_from_fov(
|
48 |
-
new_fov=60.0, H=orig_H, W=orig_W, device=device
|
49 |
-
).unsqueeze(0)
|
50 |
-
|
51 |
-
intrins[:, 0, 2] += l
|
52 |
-
intrins[:, 1, 2] += t
|
53 |
-
|
54 |
-
pred_norm = model(img, intrins=intrins)[-1]
|
55 |
-
pred_norm = pred_norm[:, :, t : t + orig_H, l : l + orig_W]
|
56 |
-
|
57 |
-
# save to output folder
|
58 |
-
# NOTE: by saving the prediction as uint8 png format, you lose a lot of precision
|
59 |
-
# if you want to use the predicted normals for downstream tasks, we recommend saving them as float32 NPY files
|
60 |
-
pred_norm_np = (
|
61 |
-
pred_norm.cpu().detach().numpy()[0, :, :, :].transpose(1, 2, 0)
|
62 |
-
) # (H, W, 3)
|
63 |
-
pred_norm_np = ((pred_norm_np + 1.0) / 2.0 * 255.0).astype(np.uint8)
|
64 |
-
target_path = img_path.replace("/img/", "/output/").replace(ext, ".png")
|
65 |
-
im = Image.fromarray(pred_norm_np)
|
66 |
-
im.save(target_path)
|
67 |
-
|
68 |
-
|
69 |
-
if __name__ == "__main__":
|
70 |
-
parser = argparse.ArgumentParser()
|
71 |
-
parser.add_argument("--ckpt", default="dsine", type=str, help="model checkpoint")
|
72 |
-
parser.add_argument("--mode", default="samples", type=str, help="{samples}")
|
73 |
-
args = parser.parse_args()
|
74 |
-
|
75 |
-
# define model
|
76 |
-
device = torch.device("cpu")
|
77 |
-
|
78 |
-
from models.dsine import DSINE
|
79 |
-
|
80 |
-
model = DSINE().to(device)
|
81 |
-
model.pixel_coords = model.pixel_coords.to(device)
|
82 |
-
model = utils.load_checkpoint("./checkpoints/%s.pt" % args.ckpt, model)
|
83 |
-
model.eval()
|
84 |
-
|
85 |
-
if args.mode == "samples":
|
86 |
-
test_samples(args, model, intrins=None, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|