import glob
import os

import hydra
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig
from PIL import Image
from torchvision import transforms as TF

from vggt.utils.pose_enc import pose_encoding_to_extri_intri

from viser_fn import viser_wrapper


def demo_fn(cfg: DictConfig, model) -> dict:
    print(cfg.SCENE_DIR)

    # Inference below runs under CUDA autocast, so fail fast without a GPU.
    if not torch.cuda.is_available():
        raise ValueError("CUDA is not available. Check your environment.")
    device = "cuda"

    model = model.to(device)

    # Collect the scene images in sorted order and preprocess them into a
    # single (N, 3, H, W) tensor, then add a batch dimension.
    image_list = sorted(glob.glob(os.path.join(cfg.SCENE_DIR, "images", "*")))
    images = load_and_preprocess_images(image_list)
    images = images[None].to(device)

    batch = {"images": images}

    # Run inference without gradients, in float16 autocast.
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
        y_hat = model(batch)

    # Decode the final (most refined) pose encoding into camera extrinsics.
    last_pred_pose_enc = y_hat["pred_extrinsic_list"][-1]
    pose_encoding_type = cfg.CameraHead.pose_encoding_type
    last_pred_extrinsic, _ = pose_encoding_to_extri_intri(
        last_pred_pose_enc.detach(),
        None,
        pose_encoding_type=pose_encoding_type,
        build_intrinsics=False,
    )
    y_hat["last_pred_extrinsic"] = last_pred_extrinsic

    # Move every tensor output to CPU numpy for downstream visualization.
    for key, value in y_hat.items():
        if isinstance(value, torch.Tensor):
            y_hat[key] = value.cpu().numpy()

    return y_hat


def load_and_preprocess_images(image_path_list):
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    images = []
    shapes = set()
    to_tensor = TF.ToTensor()

    for image_path in image_path_list:
        img = Image.open(image_path).convert("RGB")
        width, height = img.size

        # Resize to a fixed width of 518 px, scaling the height to preserve
        # the aspect ratio and rounding it to the nearest multiple of 14 so
        # it divides evenly into the model's patch grid.
        new_width = 518
        new_height = round(height * (new_width / width) / 14) * 14

        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)

        # Center-crop the height back down to 518 px if it overshoots.
        if new_height > 518:
            start_y = (new_height - 518) // 2
            img = img[:, start_y:start_y + 518, :]

        shapes.add((img.shape[1], img.shape[2]))
        images.append(img)

    # If the resized images ended up with different shapes, pad them all to
    # the largest height/width so they can be stacked into a single tensor.
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")

        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)

        padded_images = []
        for img in images:
            h_padding = max_height - img.shape[1]
            w_padding = max_width - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                # Pad symmetrically with white (1.0).
                img = torch.nn.functional.pad(
                    img,
                    (pad_left, pad_right, pad_top, pad_bottom),
                    mode="constant",
                    value=1.0,
                )
            padded_images.append(img)
        images = padded_images

    images = torch.stack(images)  # (N, C, H, W)

    # Defensive check: for a single image, make sure a leading batch
    # dimension is present.
    if len(image_path_list) == 1:
        if images.dim() == 3:
            images = images.unsqueeze(0)

    return images
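
# A minimal sketch of a Hydra entry point for this demo. The config path and
# the cfg.MODEL node passed to `instantiate` are assumptions for illustration,
# not verified parts of this repo, and viser_wrapper's exact signature may
# differ from the call shown here.
@hydra.main(version_base=None, config_path="config", config_name="base")
def main(cfg: DictConfig) -> None:
    model = instantiate(cfg.MODEL)  # assumed: config node describing the model
    predictions = demo_fn(cfg, model)
    viser_wrapper(predictions)  # assumed: takes the prediction dict to visualize


if __name__ == "__main__":
    main()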