Spaces:
Running
on
Zero
Running
on
Zero
import numpy as np | |
import imageio | |
import torch | |
import os | |
import matplotlib.pyplot as plt | |
import cv2 | |
from promptda.utils.logger import Log | |
DEVICE = 'cuda' if torch.cuda.is_available( | |
) else 'mps' if torch.backends.mps.is_available() else 'cpu' | |
def to_tensor_func(arr): | |
if arr.ndim == 2: | |
arr = arr[:, :, np.newaxis] | |
return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(DEVICE) | |
def to_numpy_func(tensor): | |
arr = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() | |
if arr.shape[2] == 1: | |
arr = arr[:, :, 0] | |
return arr | |
def ensure_multiple_of(x, multiple_of=14): | |
return int(x // multiple_of * multiple_of) | |
def load_image(image_path, to_tensor=True, max_size=1008, multiple_of=14): | |
''' | |
Load image from path and convert to tensor | |
max_size // 14 = 0 | |
''' | |
image = np.asarray(imageio.imread(image_path)).astype(np.float32) | |
image = image / 255. | |
max_size = max_size // multiple_of * multiple_of | |
if max(image.shape) > max_size: | |
h, w = image.shape[:2] | |
scale = max_size / max(h, w) | |
tar_h = ensure_multiple_of(h * scale) | |
tar_w = ensure_multiple_of(w * scale) | |
image = cv2.resize(image, (tar_w, tar_h), interpolation=cv2.INTER_AREA) | |
if to_tensor: | |
return to_tensor_func(image) | |
return image | |
def load_depth(depth_path, to_tensor=True): | |
''' | |
Load depth from path and convert to tensor | |
''' | |
if depth_path.endswith('.png'): | |
depth = np.asarray(imageio.imread(depth_path)).astype(np.float32) | |
depth = depth / 1000. | |
elif depth_path.endswith('.npz'): | |
depth = np.load(depth_path)['depth'] | |
else: | |
raise ValueError(f"Unsupported depth format: {depth_path}") | |
if to_tensor: | |
return to_tensor_func(depth) | |
return depth | |
def save_depth(depth, | |
prompt_depth=None, | |
image=None, | |
output_path='data/output/depth.png', | |
save_vis=True): | |
''' | |
Save depth to path | |
''' | |
os.makedirs(os.path.dirname(output_path), exist_ok=True) | |
depth = to_numpy_func(depth) | |
uint16_depth = (depth * 1000.).astype(np.uint16) | |
imageio.imwrite(output_path, uint16_depth) | |
if not save_vis: | |
return | |
output_path = output_path.replace('.png', '_vis.png') | |
prompt_depth = to_numpy_func(prompt_depth) | |
image = to_numpy_func(image) | |
plt.subplot(1, 3, 1) | |
plt.imshow(image) | |
plt.axis('off') | |
plt.subplot(1, 3, 2) | |
plt.imshow(prompt_depth) | |
plt.axis('off') | |
plt.subplot(1, 3, 3) | |
plt.imshow(depth) | |
plt.axis('off') | |
plt.tight_layout() | |
plt.savefig(output_path) | |
plt.close() | |
Log.info(f'Saved depth to {output_path}') | |