# Page header captured together with this file (not part of the script):
# akhaliq3 · "New message for the combined commit" · 833ef7e · raw · history blame · 3.12 kB
import os
import torch
import numpy as np
from tqdm import trange
from PIL import Image
def get_state(gpu):
    """Load the MiDaS model and its default input transform from torch.hub.

    Args:
        gpu: when truthy, the model is moved to CUDA with ``.cuda()``.

    Returns:
        dict with key ``"model"`` (the MiDaS network, switched to eval mode)
        and key ``"transform"`` (the hub's ``default_transform``).
    """
    hub_repo = "intel-isl/MiDaS"
    net = torch.hub.load(hub_repo, "MiDaS")
    if gpu:
        net.cuda()
    net.eval()
    default_transform = torch.hub.load(hub_repo, "transforms").default_transform
    return {"model": net, "transform": default_transform}
def depth_to_rgba(x):
    """Pack a float32 depth map into a 4-channel uint8 array, bit for bit.

    Each float32 value is reinterpreted (not rescaled) as its four raw bytes,
    which become the four channels of the corresponding output pixel. The
    result can therefore be stored losslessly as an RGBA image and decoded by
    reinterpreting the bytes back to float32.

    Args:
        x: 2-D numpy array of dtype float32, shape (H, W).

    Returns:
        A new C-contiguous uint8 array of shape (H, W, 4) that does not share
        memory with ``x``. The byte order inside each pixel is the machine's
        native float32 byte order.

    Raises:
        ValueError: if ``x`` is not float32 or is not 2-D.
    """
    # Explicit checks instead of assert, which is stripped under ``python -O``.
    if x.dtype != np.float32:
        raise ValueError("expected a float32 depth map, got dtype {}".format(x.dtype))
    if x.ndim != 2:
        raise ValueError("expected a 2-D depth map, got shape {}".format(x.shape))
    # copy() gives an owned C-contiguous buffer; view() reinterprets its bytes
    # (4 uint8 per float32) instead of assigning to ``.dtype``, which NumPy
    # discourages. The (H, W*4) view is then split into (H, W, 4).
    raw_bytes = x.copy().view(np.uint8)
    return np.ascontiguousarray(raw_bytes.reshape(x.shape + (4,)))
def rgba_to_depth(x):
    """Unpack a 4-channel uint8 array back into a float32 depth map, bit for bit.

    The four channel bytes of each pixel are reinterpreted (not rescaled) as
    one float32 value in the machine's native byte order.

    Args:
        x: numpy array of dtype uint8 and shape (H, W, 4).

    Returns:
        A new C-contiguous float32 array of shape (H, W) that does not share
        memory with ``x``.

    Raises:
        ValueError: if ``x`` is not uint8 or its shape is not (H, W, 4).
    """
    # Explicit checks instead of assert, which is stripped under ``python -O``.
    if x.dtype != np.uint8:
        raise ValueError("expected a uint8 RGBA array, got dtype {}".format(x.dtype))
    if x.ndim != 3 or x.shape[2] != 4:
        raise ValueError("expected shape (H, W, 4), got {}".format(x.shape))
    # copy() gives an owned C-contiguous buffer; view() reinterprets each run of
    # 4 bytes as one float32 (shape (H, W, 1)) instead of assigning ``.dtype``.
    as_float = x.copy().view(np.float32)
    return np.ascontiguousarray(as_float.reshape(x.shape[:2]))
def run(x, state):
    """Predict a depth map for a single image with the model held in ``state``.

    Args:
        x: image array whose first two axes are (H, W). It is mapped with
            ``(x + 1.0) * 127.5`` before the transform, which presumably means
            ``x`` is scaled to [-1, 1] and the transform expects [0, 255];
            TODO confirm against the dataset.
        state: dict with ``"model"`` (a ``torch.nn.Module``) and
            ``"transform"`` (a callable turning the rescaled image into an input
            tensor for the model), e.g. as built by ``get_state``.

    Returns:
        numpy array holding the model output resized to (H, W) with bicubic
        interpolation (2-D when the model returns a single-item batch).
    """
    model = state["model"]
    transform = state["transform"]
    out_size = x.shape[:2]
    # Send the input to wherever the model's weights live instead of forcing
    # CUDA, so a CPU-loaded model (get_state(gpu=False)) also works.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    with torch.no_grad():
        prediction = model(transform((x + 1.0) * 127.5).to(device))
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=out_size,
            mode="bicubic",
            align_corners=False,
        )
        # Drop only the channel axis and a singleton batch axis; a bare
        # squeeze() would also collapse H or W when either of them equals 1.
        prediction = prediction.squeeze(1).squeeze(0)
    return prediction.cpu().numpy()
def get_filename(relpath, level=-2):
    """Split a relative image path into (class folder, file stem).

    Used to mirror the dataset's ``<folder>/<file>`` layout in the output tree.

    Args:
        relpath: path whose components are separated by ``os.sep``; at least
            two components must remain after slicing with ``level``.
        level: start index of the trailing components to keep. Only the last
            two kept components are used, so any ``level`` that keeps at least
            two gives the same result.

    Returns:
        (folder, stem): the parent directory name and the file name with only
        its final extension removed.

    Raises:
        IndexError: if fewer than two components remain.
    """
    parts = relpath.split(os.sep)[level:]
    folder = parts[-2]
    # splitext removes only the last extension: "a.b.JPEG" -> "a.b" and
    # ".hidden" -> ".hidden". split('.')[0] gave "a" (possible name collision)
    # and "" (an empty output name).
    stem = os.path.splitext(parts[-1])[0]
    return folder, stem
def save_depth(dataset, path, debug=False):
    """Run the depth model over ``dataset`` and write each map as a PNG under ``path``.

    Each example's depth map is packed bit-for-bit into an RGBA image with
    ``depth_to_rgba`` and saved as ``<path>/<folder>/<stem>.png``. The folder
    and stem are derived from the example's ``relpath`` by ``get_filename``.

    Args:
        dataset: indexable, sized collection whose items are dicts with keys
            ``"image"`` (input for ``run``) and ``"relpath"``.
        path: output directory. It must not exist yet (``os.makedirs`` raises
            otherwise).
        debug: when true, process at most the first 10 examples.
    """
    os.makedirs(path)
    # Use the dataset that was passed in, not a module-level global.
    num_examples = len(dataset)
    if debug:
        # Cap at 10 examples without indexing past the end of a small dataset.
        num_examples = min(num_examples, 10)
    state = get_state(gpu=True)
    for idx in trange(num_examples, desc="Data"):
        example = dataset[idx]
        image, relpath = example["image"], example["relpath"]
        folder, filename = get_filename(relpath)
        folder_abspath = os.path.join(path, folder)
        os.makedirs(folder_abspath, exist_ok=True)
        savepath = os.path.join(folder_abspath, filename)
        rgba = depth_to_rgba(run(image, state))
        Image.fromarray(rgba).save("{}.png".format(savepath))
if __name__ == "__main__":
    import sys

    from taming.data.imagenet import ImageNetTrain, ImageNetValidation

    out = "data/imagenet_depth"
    if not os.path.exists(out):
        print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
              "(be prepared that the output size will be larger than ImageNet itself).",
              file=sys.stderr)
        # sys.exit rather than the site-provided exit(), which is missing under `python -S`.
        sys.exit(1)

    # (dataset class, output subdirectory, completion message), processed in order.
    splits = (
        (ImageNetValidation, "val", "done with validation split"),
        (ImageNetTrain, "train", "done with train split"),
    )
    for dataset_cls, subdir, done_msg in splits:
        abspath = os.path.join(out, subdir)
        if os.path.exists(abspath):
            print("{} exists - not doing anything.".format(abspath))
        else:
            # Build the dataset only when this split is actually processed.
            dset = dataset_cls()
            print("preparing {}".format(abspath))
            save_depth(dset, abspath)
        print(done_msg)
    print("done done.")