metadata
pipeline_tag: feature-extraction



Project Page | arXiv | PDF

NeRF-MAE: Masked AutoEncoders for Self-Supervised 3D Representation Learning for Neural Radiance Fields

Muhammad Zubair Irshad · Sergey Zakharov · Vitor Guizilini · Adrien Gaidon · Zsolt Kira · Rares Ambrus
European Conference on Computer Vision, ECCV 2024

Toyota Research Institute   |   Georgia Institute of Technology

💡 Highlights

  • NeRF-MAE: The first large-scale pretraining utilizing Neural Radiance Fields (NeRF) as an input modality. We pretrain a single Transformer model on thousands of NeRFs for 3D representation learning.
  • NeRF-MAE Dataset: A large-scale NeRF pretraining and downstream task finetuning dataset.

🏷️ TODO 🚀

  • Release large-scale pretraining code 🚀
  • Release NeRF-MAE dataset comprising radiance and density grids 🚀
  • Release 3D object detection finetuning and eval code 🚀
  • Pretrained NeRF-MAE checkpoints and out-of-the-box model usage 🚀

NeRF-MAE Model Architecture

Citation

If you find this repository or our dataset useful, please star ⭐ this repository and consider citing 📝:

@inproceedings{irshad2024nerfmae,
    title={NeRF-MAE: Masked AutoEncoders for Self-Supervised 3D Representation Learning for Neural Radiance Fields},
    author={Muhammad Zubair Irshad and Sergey Zakharov and Vitor Guizilini and Adrien Gaidon and Zsolt Kira and Rares Ambrus},
    booktitle={European Conference on Computer Vision (ECCV)},
    year={2024}
    }


🌇 Environment

Create a Python 3.9 virtual environment and install the requirements:

cd NeRF-MAE  # root of your clone of this repository
conda create -n nerf-mae python=3.9
conda activate nerf-mae
pip install --upgrade pip
pip install -r requirements.txt
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html

The code was built and tested on CUDA 11.3.
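The install commands pin torch 1.12.1 built against CUDA 11.3, which pip encodes in the `+cu113` local version tag. The helper below is a small sketch, not part of the repository, that parses such a version string and checks it against those pins; in practice you would pass it `torch.__version__`.

```python
import re


def parse_torch_build(version: str):
    """Split a PyTorch version such as '1.12.1+cu113' into (release, cuda).

    cuda is returned as e.g. '11.3', or None for CPU or untagged builds.
    """
    release, _, local = version.partition("+")
    # 'cu113' -> major '11', minor '3'; the last digit is always the minor version.
    match = re.fullmatch(r"cu(\d+)(\d)", local)
    cuda = f"{match.group(1)}.{match.group(2)}" if match else None
    return release, cuda


def matches_readme_pins(version: str) -> bool:
    """True if the build matches the pinned torch 1.12.1 + CUDA 11.3 wheel."""
    return parse_torch_build(version) == ("1.12.1", "11.3")
```

For example, `parse_torch_build("1.12.1+cu113")` returns `("1.12.1", "11.3")`, while a build without a `+cuXXX` tag yields `None` for the CUDA part.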

To run downstream task finetuning, compile the CUDA extension as described in NeRF-RPN:

cd NeRF-MAE  # root of your clone of this repository
cd nerf_rpn/model/rotated_iou/cuda_op
python setup.py install
cd ../../../..

⛳ Model Usage and Checkpoints

NeRF-MAE is structured to provide easy access to pretrained NeRF-MAE models (and reproductions) to facilitate their use in various downstream tasks. This is useful for extracting good visual features from NeRFs if you don't have the resources for large-scale pretraining. Our pretraining provides an easy-to-access embedding of any NeRF scene, which can be used for a variety of downstream tasks in a straightforward way.
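If a downstream task needs one fixed-size vector per scene instead of multi-scale feature maps, one simple option (our assumption, not necessarily what NeRF-MAE itself does) is to average-pool each stage's features over the spatial axes and concatenate the results. The sketch uses NumPy arrays in the `[N, C, H, W, D]` layout produced by the feature-extraction example; with torch tensors the equivalent is `.mean(dim=(2, 3, 4))` followed by `torch.cat`.

```python
import numpy as np


def pooled_scene_embedding(features):
    """Global-average-pool each [N, C, H, W, D] map and concatenate along channels.

    Returns an array of shape [N, sum(C_i)].
    """
    pooled = [f.mean(axis=(2, 3, 4)) for f in features]  # each becomes [N, C_i]
    return np.concatenate(pooled, axis=1)
```

For the `swin_s` feature shapes listed in the example (96, 192, 384 and 768 channels), this gives a 96 + 192 + 384 + 768 = 1440-dimensional vector per scene.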

We have released pretrained and finetuned checkpoints so that you can use our codebase out of the box. There are two main ways to use the model: 1. the most common is to use the features directly in a downstream task, such as an FPN head for 3D object detection, and 2. reconstruct the original grid, e.g. to enforce losses such as a masked reconstruction loss. Below are sample usages of our model with detailed comments.

  1. Get the features to be used in a downstream task
import torch
from huggingface_hub import hf_hub_download

# SwinTransformer_MAE3D_New is defined in the NeRF-MAE codebase; import it from there before running this snippet.

# Define Swin Transformer configurations
swin_config = {
    "swin_t": {"embed_dim": 96, "depths": [2, 2, 6, 2], "num_heads": [3, 6, 12, 24]},
    "swin_s": {"embed_dim": 96, "depths": [2, 2, 18, 2], "num_heads": [3, 6, 12, 24]},
    "swin_b": {"embed_dim": 128, "depths": [2, 2, 18, 2], "num_heads": [3, 6, 12, 24]},
    "swin_l": {"embed_dim": 192, "depths": [2, 2, 18, 2], "num_heads": [6, 12, 24, 48]},
}

# Set the desired backbone type
backbone_type = "swin_s"
config = swin_config[backbone_type]

# Grid resolution of the input NeRF (assumed to match the 160^3 sample grid below)
resolution = 160

# Initialize Swin Transformer model
model = SwinTransformer_MAE3D_New(
    patch_size=[4, 4, 4],
    embed_dim=config["embed_dim"],
    depths=config["depths"],
    num_heads=config["num_heads"],
    window_size=[4, 4, 4],
    stochastic_depth_prob=0.1,
    expand_dim=True,
    resolution=resolution,
)

# Load checkpoint and remove layers that are only needed for reconstruction (decoder and mask token)
checkpoint_path = hf_hub_download(repo_id="mirshad7/NeRF-MAE", filename="nerf_mae_pretrained.pt")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
for attr in ["decoder4", "decoder3", "decoder2", "decoder1", "out", "mask_token"]:
    delattr(model, attr)
model.eval()  # Disable stochastic depth for feature extraction

# Extract features using the Swin Transformer backbone.
# Sample input: a 4-channel (RGB + density) grid of shape [1, 4, 160, 160, 160]
input_grid = torch.randn((1, 4, resolution, resolution, resolution))
features = []
with torch.no_grad():
    x = model.patch_partition(input_grid) + model.pos_embed.type_as(input_grid).to(input_grid.device).clone().detach()
    for stage in model.stages:
        x = stage(x)
        features.append(torch.permute(x, [0, 4, 1, 2, 3]).contiguous())  # Format: [N, C, H, W, D]

# Multi-scale features have shape: [torch.Size([1, 96, 40, 40, 40]), torch.Size([1, 192, 20, 20, 20]), torch.Size([1, 384, 10, 10, 10]), torch.Size([1, 768, 5, 5, 5])]

# Pass these multi-scale features to a downstream head, e.g. an FPN for 3D object detection
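The shapes in the comment follow a simple pattern for `swin_s` at resolution 160: the 4x4x4 patch partition yields a 40³ grid, and each later stage halves the spatial size and doubles the channel count. The helper below is a sketch that assumes this pattern (read off the documented output, not taken from the model code) also holds for other configurations, so you can predict the shapes a downstream head will receive.

```python
def expected_feature_shapes(resolution, patch_size, embed_dim, num_stages=4, batch=1):
    """Predict per-stage [N, C, H, W, D] feature shapes of the backbone.

    Assumes stage i has embed_dim * 2**i channels and a spatial size of
    (resolution // patch_size) // 2**i along each axis.
    """
    if resolution % (patch_size * 2 ** (num_stages - 1)) != 0:
        raise ValueError("resolution must be divisible by patch_size * 2**(num_stages - 1)")
    base = resolution // patch_size  # grid size after patch partition
    return [
        (batch, embed_dim * 2 ** i, base // 2 ** i, base // 2 ** i, base // 2 ** i)
        for i in range(num_stages)
    ]
```

With `expected_feature_shapes(160, 4, 96)` this reproduces the four shapes listed in the comment of the example.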
  2. Get the Original Grid Output
import torch

# load_data, build_model, args (parsed command-line arguments), folder_name and filename
# are assumed to come from the NeRF-MAE scripts; they are not defined in this snippet.

# Load data from the specified folder and filename with the given resolution.
res, rgbsigma = load_data(folder_name, filename, resolution=args.resolution)

# rgbsigma has sample shape torch.randn((1, 4, 160, 160, 160))

# Build the model using provided arguments.
model = build_model(args)

# Load checkpoint if provided.
if args.checkpoint:
    model.load_state_dict(torch.load(args.checkpoint, map_location="cpu")["state_dict"])
model.eval()  # Set model to evaluation mode.

# Run inference to reconstruct the original grid
with torch.no_grad():
    pred = model([rgbsigma], is_eval=True)[3]  # Keep only the predictions (index 3 of the model output).
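Reconstructing the grid is what makes a masked reconstruction loss possible. The sketch below shows a generic MAE-style version in NumPy: hide a random subset of patches, then compute the mean squared error only over voxels inside the hidden patches. It is an illustration under our own assumptions (the helper names, the plain MSE, the channel averaging); the loss actually used by NeRF-MAE may differ, for example in how color and density are weighted.

```python
import numpy as np


def random_patch_mask(grid_shape, patch_size, mask_ratio, rng):
    """Boolean mask over the patch grid; True marks a masked (hidden) patch."""
    patches = [s // patch_size for s in grid_shape]
    num_patches = int(np.prod(patches))
    num_masked = int(round(mask_ratio * num_patches))
    flat = np.zeros(num_patches, dtype=bool)
    flat[rng.choice(num_patches, size=num_masked, replace=False)] = True
    return flat.reshape(patches)


def masked_mse(pred, target, patch_mask, patch_size):
    """Mean squared error over voxels that fall inside masked patches.

    pred, target: [C, H, W, D] grids (e.g. RGB + density channels).
    patch_mask:   [H/p, W/p, D/p] boolean patch mask.
    Returns NaN if no patch is masked.
    """
    voxel_mask = patch_mask
    for axis in range(3):  # upsample the patch mask to voxel resolution
        voxel_mask = np.repeat(voxel_mask, patch_size, axis=axis)
    err = ((pred - target) ** 2).mean(axis=0)  # per-voxel error averaged over channels
    return float(err[voxel_mask].mean())
```

Restricting the loss to hidden patches follows the usual MAE design: the model gains nothing by copying visible inputs and must infer the missing structure.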

3. How to plug these features into downstream 3D bounding box detection from NeRFs (i.e. plug-and-play with a NeRF-RPN OBB prediction head)

Please also see the section on Finetuning. Our released finetuned checkpoint achieves state-of-the-art results on 3D object detection in NeRFs. To run evaluation with our finetuned checkpoint on the dataset provided by NeRF-RPN, update the checkpoint path (--checkpoint) and DATA_ROOT, depending on whether you evaluate on Front3D or ScanNet, and then run the script below:

bash test_fcos_pretrained.sh

Also see the corresponding run file, i.e. run_fcos_pretrained.py, and our model adaptation, i.e. SwinTransformer_FPN_Pretrained_Skip. This is a minimal adaptation to plug and play our weights with a NeRF-RPN architecture and achieve a significant boost in performance.
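Detection heads such as NeRF-RPN's consume multi-scale features through a feature pyramid. To show how the four backbone stages can be fused, here is a minimal NumPy sketch of a generic FPN top-down pathway (Lin et al., 2017): a 1x1x1 lateral projection brings each stage to a common channel count, and coarser maps are upsampled 2x and added to finer ones. The random lateral weights and helper names are illustrative assumptions; the actual NeRF-RPN FPN and SwinTransformer_FPN_Pretrained_Skip use learned weights and may add further layers.

```python
import numpy as np


def lateral_1x1(x, weight):
    """1x1x1 convolution as a channel projection.

    x: [N, C_in, H, W, D], weight: [C_out, C_in] -> [N, C_out, H, W, D].
    """
    return np.einsum("oc,nchwd->nohwd", weight, x)


def upsample2x(x):
    """Nearest-neighbour 2x upsampling along the three spatial axes."""
    for axis in (2, 3, 4):
        x = np.repeat(x, 2, axis=axis)
    return x


def fpn_top_down(features, out_channels, rng):
    """Fuse fine-to-coarse features [N, C_i, S_i, S_i, S_i] (S_i halving per level).

    Returns one map per level, each with out_channels channels.
    """
    laterals = [
        lateral_1x1(f, rng.standard_normal((out_channels, f.shape[1])) / np.sqrt(f.shape[1]))
        for f in features
    ]
    outputs = [laterals[-1]]  # start from the coarsest level
    for lat in reversed(laterals[:-1]):
        outputs.insert(0, lat + upsample2x(outputs[0]))  # add upsampled coarser map
    return outputs
```

Applied to the four `swin_s` stages (40³ down to 5³), this yields four maps with a shared channel count at the original resolutions, which is the interface an anchor-free head like FCOS typically expects.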

🗂️ Dataset

Download the preprocessed datasets here.

Extract the pretraining and finetuning datasets under NeRF-MAE/datasets. The directory structure should look like this:

NeRF-MAE
├── pretrain
│   ├── features
│   └── nerfmae_split.npz
└── finetune
    └── front3d_rpn_data
        ├── features
        ├── aabb
        └── obb
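After extracting, a quick way to confirm the layout is to check for each entry of the tree. The sketch below uses only the standard library; `EXPECTED_LAYOUT` simply transcribes the tree shown here, and `root` is assumed to be the directory that directly contains `pretrain` and `finetune`.

```python
from pathlib import Path

# Entries transcribed from the directory tree above, relative to the dataset root.
EXPECTED_LAYOUT = [
    "pretrain/features",
    "pretrain/nerfmae_split.npz",
    "finetune/front3d_rpn_data/features",
    "finetune/front3d_rpn_data/aabb",
    "finetune/front3d_rpn_data/obb",
]


def missing_entries(root):
    """Return the expected files/folders that do not exist under root."""
    root = Path(root)
    return [rel for rel in EXPECTED_LAYOUT if not (root / rel).exists()]
```

An empty list from `missing_entries(...)` means every expected entry is present; otherwise the returned paths tell you what is still missing.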

Note: The above datasets are all you need to train and evaluate our method. Bonus: we will soon release our multi-view rendered posed RGB images from FRONT3D, HM3D and Hypersim, as well as Instant-NGP trained checkpoints (over 1M images and over 3k NeRF checkpoints in total).

Please note that our dataset was generated using the instructions from NeRF-RPN and 3D-CLR. Please consider citing our work, NeRF-RPN and 3D-CLR if you find this dataset useful in your research.

Please also note that our dataset uses Front3D, Habitat-Matterport3D, Hypersim and ScanNet as base datasets, i.e. we train a NeRF per scene and extract radiance and density grids as well as aligned NeRF-grid 3D annotations. Please read the terms of use of each dataset if you want to use its posed multi-view images.

For more details, please check out our paper, GitHub and project page!


license: cc-by-nc-4.0