



Project Page
NeRF-MAE : Masked AutoEncoders for Self-Supervised 3D Representation Learning for Neural Radiance Fields
Muhammad Zubair Irshad · Sergey Zakharov · Vitor Guizilini · Adrien Gaidon · Zsolt Kira · Rares Ambrus
European Conference on Computer Vision, ECCV 2024
Toyota Research Institute | Georgia Institute of Technology
Highlights
- NeRF-MAE: The first large-scale pretraining utilizing Neural Radiance Fields (NeRF) as an input modality. We pretrain a single Transformer model on thousands of NeRFs for 3D representation learning.
- NeRF-MAE Dataset: A large-scale NeRF pretraining and downstream task finetuning dataset.
TODO
- Release large-scale pretraining code
- Release NeRF-MAE dataset comprising radiance and density grids
- Release 3D object detection finetuning and eval code
- Pretrained NeRF-MAE checkpoints and out-of-the-box model usage
NeRF-MAE Model Architecture
Citation
If you find this repository or our dataset useful, please star this repository and consider citing:
@inproceedings{irshad2024nerfmae,
  title={NeRF-MAE: Masked AutoEncoders for Self-Supervised 3D Representation Learning for Neural Radiance Fields},
  author={Muhammad Zubair Irshad and Sergey Zakharov and Vitor Guizilini and Adrien Gaidon and Zsolt Kira and Rares Ambrus},
  booktitle={European Conference on Computer Vision (ECCV)},
  year={2024}
}
Environment
Create a Python 3.9 virtual environment and install the requirements:
cd $NeRF-MAE repo
conda create -n nerf-mae python=3.9
conda activate nerf-mae
pip install --upgrade pip
pip install -r requirements.txt
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html
The code was built and tested with CUDA 11.3.
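As a quick sanity check after installation, the sketch below compares the installed PyTorch build against the versions pinned above (torch 1.12.1, CUDA 11.3). The helper names `versions_match` and `installed_versions` are hypothetical and not part of this repository; the script only reports what it finds and changes nothing.

```python
# Sanity check for the pinned versions above (helper names are hypothetical).
EXPECTED_TORCH = "1.12.1"
EXPECTED_CUDA = "11.3"


def versions_match(torch_version, cuda_version):
    """Return True when a build string such as "1.12.1+cu113" matches the pins.

    cuda_version is the CUDA version the wheel was built with, or None for CPU-only builds.
    """
    base = torch_version.split("+")[0]  # drop the local "+cu113" suffix
    return base == EXPECTED_TORCH and cuda_version == EXPECTED_CUDA


def installed_versions():
    """Return (torch version, CUDA build version), or None if torch is not installed."""
    try:
        import torch
    except ImportError:
        return None
    return torch.__version__, torch.version.cuda


found = installed_versions()
if found is None:
    print("torch is not installed in this environment")
else:
    print("torch:", found[0], "| CUDA build:", found[1])
    print("matches pinned versions:", versions_match(*found))
```

A mismatch does not necessarily mean the code will fail; it only means the environment differs from the one the code was tested on.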
To run downstream task finetuning, compile the CUDA extension as described in NeRF-RPN:
cd $NeRF-MAE repo
cd nerf_rpn/model/rotated_iou/cuda_op
python setup.py install
cd ../../../..
Model Usage and Checkpoints
NeRF-MAE is structured to provide easy access to pretrained NeRF-MAE models (and reproductions) for use in various downstream tasks. This is useful for extracting good visual features from NeRFs if you do not have the resources for large-scale pretraining. Our pretraining provides an easy-to-access embedding of any NeRF scene, which can be used for a variety of downstream tasks in a straightforward way.
We have released pretrained and finetuned checkpoints so that our codebase can be used out of the box. There are two types of usage: (1) most commonly, using the features directly in a downstream task, for example with an FPN head for 3D object detection, and (2) reconstructing the original grid, for example to enforce a masked reconstruction loss. Sample usage of our model, with explanatory comments, is given below.
- Get the features to be used in a downstream task
import torch

# SwinTransformer_MAE3D_New is defined in this repository and must be imported from it.
# `resolution`, `checkpoint_path` and `input_grid` must be set by the caller.

# Define Swin Transformer configurations
swin_config = {
    "swin_t": {"embed_dim": 96, "depths": [2, 2, 6, 2], "num_heads": [3, 6, 12, 24]},
    "swin_s": {"embed_dim": 96, "depths": [2, 2, 18, 2], "num_heads": [3, 6, 12, 24]},
    "swin_b": {"embed_dim": 128, "depths": [2, 2, 18, 2], "num_heads": [3, 6, 12, 24]},
    "swin_l": {"embed_dim": 192, "depths": [2, 2, 18, 2], "num_heads": [6, 12, 24, 48]},
}

# Set the desired backbone type
backbone_type = "swin_s"
config = swin_config[backbone_type]

# Initialize Swin Transformer model
model = SwinTransformer_MAE3D_New(
    patch_size=[4, 4, 4],
    embed_dim=config["embed_dim"],
    depths=config["depths"],
    num_heads=config["num_heads"],
    window_size=[4, 4, 4],
    stochastic_depth_prob=0.1,
    expand_dim=True,
    resolution=resolution,
)

# Load checkpoint and remove unused layers
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
for attr in ["decoder4", "decoder3", "decoder2", "decoder1", "out", "mask_token"]:
    delattr(model, attr)

# Extract features using the Swin Transformer backbone.
# input_grid has sample shape (1, 4, 160, 160, 160), e.g. torch.randn((1, 4, 160, 160, 160))
features = []
input_grid = model.patch_partition(input_grid) + model.pos_embed.type_as(input_grid).to(input_grid.device).clone().detach()
for stage in model.stages:
    input_grid = stage(input_grid)
    features.append(torch.permute(input_grid, [0, 4, 1, 2, 3]).contiguous())  # Format: [N, C, H, W, D]

# Multi-scale features have shapes: [torch.Size([1, 96, 40, 40, 40]), torch.Size([1, 192, 20, 20, 20]), torch.Size([1, 384, 10, 10, 10]), torch.Size([1, 768, 5, 5, 5])]
# Process features through FPN
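The multi-scale shapes listed above follow a simple pattern, which the sketch below reproduces with plain arithmetic. It assumes that the patch partition divides each spatial side by the patch size and that each later stage halves the side and doubles the channels; this matches the shapes listed for `swin_s` but is inferred from them rather than read from the model code. The function name `expected_feature_shapes` is hypothetical.

```python
def expected_feature_shapes(resolution=160, patch_size=4, embed_dim=96, num_stages=4, batch=1):
    """Return the [N, C, H, W, D] shape expected after each backbone stage.

    Assumption: the side starts at resolution // patch_size and halves per stage,
    while the channel count starts at embed_dim and doubles per stage.
    """
    shapes = []
    side = resolution // patch_size
    channels = embed_dim
    for _ in range(num_stages):
        shapes.append((batch, channels, side, side, side))
        side //= 2
        channels *= 2
    return shapes


for shape in expected_feature_shapes():
    print(shape)
```

Comparing these values against the shapes of the entries in `features` is a cheap way to catch a mismatched `resolution` or backbone configuration before attaching a detection head.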
- Get the Original Grid Output
import torch

# load_data, build_model and args are assumed to come from this repository's run script.

# Load data from the specified folder and filename with the given resolution.
res, rgbsigma = load_data(folder_name, filename, resolution=args.resolution)
# rgbsigma has sample shape (1, 4, 160, 160, 160), e.g. torch.randn((1, 4, 160, 160, 160))

# Build the model using provided arguments.
model = build_model(args)

# Load checkpoint if provided.
if args.checkpoint:
    model.load_state_dict(torch.load(args.checkpoint, map_location="cpu")["state_dict"])

model.eval()  # Set model to evaluation mode.

# Run inference to get the model's predictions for downstream usage
with torch.no_grad():
    pred = model([rgbsigma], is_eval=True)[3]  # Extract only predictions.
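To make the idea of a masked reconstruction loss concrete, here is a minimal sketch of a mean squared error computed only over masked positions. It is written in plain Python over flat sequences for clarity and is not the loss implemented in this repository; the function name `masked_mse` and the toy values are illustrative assumptions. In practice the same computation would be done on tensors.

```python
def masked_mse(pred, target, mask):
    """Mean squared error over the positions where mask is truthy.

    pred, target and mask are flat sequences of equal length; mask marks the
    voxels that were hidden from the encoder and must be reconstructed.
    """
    if not (len(pred) == len(target) == len(mask)):
        raise ValueError("pred, target and mask must have the same length")
    errors = [(p - t) ** 2 for p, t, m in zip(pred, target, mask) if m]
    if not errors:
        return 0.0  # nothing was masked, so there is nothing to penalise
    return sum(errors) / len(errors)


# Toy example: only positions 0 and 2 are masked and contribute to the loss.
pred = [0.5, 0.0, 1.0, 0.25]
target = [1.0, 0.0, 0.0, 0.25]
mask = [1, 0, 1, 0]
loss = masked_mse(pred, target, mask)
print(loss)
```

Restricting the error to masked positions means the model is scored only on content it could not see, which is the point of masked autoencoding.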
1. How to plug these features into downstream 3D bounding box detection from NeRFs (i.e. plug-and-play with a NeRF-RPN OBB prediction head)
Please also see the section on Finetuning. Our released finetuned checkpoint achieves state-of-the-art results on 3D object detection in NeRFs. To evaluate our finetuned checkpoint on the dataset provided by NeRF-RPN, run the script below after updating the path to the pretrained checkpoint (--checkpoint) and DATA_ROOT, depending on whether the evaluation is done for Front3D or ScanNet:
bash test_fcos_pretrained.sh
Also see the corresponding run file, run_fcos_pretrained.py, and our model adaptation, SwinTransformer_FPN_Pretrained_Skip. This is a minimal adaptation that lets our weights plug into a NeRF-RPN architecture and yields a significant boost in performance.
Dataset
Download the preprocessed datasets here.
- Pretraining dataset (comprising NeRF radiance and density grids). Download link
- Finetuning dataset (comprising NeRF radiance and density grids and bounding box/semantic labelling annotations). 3D Object Detection (Provided by NeRF-RPN), 3D Semantic Segmentation (Coming Soon), Voxel-Super Resolution (Coming Soon)
Extract the pretraining and finetuning datasets under NeRF-MAE/datasets. The directory structure should look like this:
NeRF-MAE
├── pretrain
│   ├── features
│   └── nerfmae_split.npz
└── finetune
    └── front3d_rpn_data
        ├── features
        ├── aabb
        └── obb
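After extraction, the layout shown above can be checked with a short script such as the sketch below. The helper `missing_entries` and the default root `datasets` are assumptions for illustration; pass whichever directory you extracted the archives into.

```python
from pathlib import Path

# Relative paths taken from the directory structure shown above.
EXPECTED_ENTRIES = [
    "pretrain/features",
    "pretrain/nerfmae_split.npz",
    "finetune/front3d_rpn_data/features",
    "finetune/front3d_rpn_data/aabb",
    "finetune/front3d_rpn_data/obb",
]


def missing_entries(dataset_root):
    """Return the expected files and folders that do not exist under dataset_root."""
    root = Path(dataset_root)
    return [rel for rel in EXPECTED_ENTRIES if not (root / rel).exists()]


if __name__ == "__main__":
    import sys

    # Hypothetical default; pass the real dataset root as the first argument.
    root = sys.argv[1] if len(sys.argv) > 1 else "datasets"
    missing = missing_entries(root)
    print("layout OK" if not missing else f"missing under {root}: {missing}")
```

The check only tests for existence, so it will not detect partially extracted or corrupted files.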
Note: The above datasets are all you need to train and evaluate our method. Bonus: we will soon release our multi-view rendered posed RGB images from FRONT3D, HM3D and Hypersim, as well as Instant-NGP trained checkpoints (these comprise over 1M images and over 3k NeRF checkpoints).
Please note that our dataset was generated using the instructions from NeRF-RPN and 3D-CLR. Please consider citing our work, NeRF-RPN and 3D-CLR if you find this dataset useful in your research.
Please also note that our dataset uses Front3D, Habitat-Matterport3D, HyperSim and ScanNet as the base datasets, i.e. we train a NeRF per scene and extract the radiance and density grid as well as aligned NeRF-grid 3D annotations. Please read the terms of use for each dataset if you want to use its posed multi-view images.