# JaxNeRF
This is a [JAX](https://github.com/google/jax) implementation of
[NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](http://www.matthewtancik.com/nerf).
This code is created and maintained by
[Boyang Deng](https://boyangdeng.com/),
[Jon Barron](https://jonbarron.info/),
and [Pratul Srinivasan](https://people.eecs.berkeley.edu/~pratul/).
Our JAX implementation currently supports:

| Task | Single-Host GPU: Single-Device | Single-Host GPU: Multi-Device | Multi-Device TPU: Single-Host | Multi-Device TPU: Multi-Host |
|------------|:---:|:---:|:---:|:---:|
| Training | ✔ | ✔ | ✔ | ✔ |
| Evaluation | ✔ | ✔ | ✔ | ✔ |

The training job on 128 TPUv2 cores takes only **2.5 hours (vs. 3 days for TF
NeRF)** for 1 million optimization steps. In other words, JaxNeRF trains to top quality, and it does so very quickly.
As for inference speed, here are the timings for rendering one image at
800x800 resolution (averaged over 50 rendering passes):

| Platform | 1 x NVIDIA V100 | 8 x NVIDIA V100 | 128 x TPUv2 |
|----------|:---------------:|:-----------------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------------------------:|
| TF NeRF | 27.74 secs | | |
| JaxNeRF | 20.77 secs | 2.65 secs | 0.35 secs |
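
As a quick sanity check on these numbers, the sketch below converts the reported timings into rays per second and speedups. It only reuses figures from this README and assumes one camera ray per pixel (640,000 rays for an 800x800 image); it is plain arithmetic, not a benchmark.

```python
# Re-derive throughput and speedups from the numbers reported above.
# ASSUMPTION: one camera ray per pixel when rendering an image.
WIDTH, HEIGHT = 800, 800
RAYS_PER_IMAGE = WIDTH * HEIGHT

# Seconds to render one 800x800 image, copied from the table above.
render_secs = {
    "TF NeRF, 1 x V100": 27.74,
    "JaxNeRF, 1 x V100": 20.77,
    "JaxNeRF, 8 x V100": 2.65,
    "JaxNeRF, 128 x TPUv2": 0.35,
}

# Training time for 1M steps: 3 days (TF NeRF) vs. 2.5 hours (JaxNeRF, TPUs).
TRAIN_HOURS_TF = 3 * 24
TRAIN_HOURS_JAX = 2.5


def rays_per_second(secs: float) -> float:
    """Rays rendered per second for a full image."""
    return RAYS_PER_IMAGE / secs


def speedup(baseline_secs: float, secs: float) -> float:
    """How many times faster `secs` is than `baseline_secs`."""
    return baseline_secs / secs


if __name__ == "__main__":
    baseline = render_secs["JaxNeRF, 1 x V100"]
    for name, secs in render_secs.items():
        print(f"{name:22s} {rays_per_second(secs):>12,.0f} rays/s  "
              f"{speedup(baseline, secs):5.2f}x vs. JaxNeRF on 1 x V100")
    print(f"Training speedup: {TRAIN_HOURS_TF / TRAIN_HOURS_JAX:.1f}x")
```

By this arithmetic, 8 V100s render roughly 7.8x faster than a single V100, 128 TPUv2 cores roughly 59x faster, and the reported training time is about 28.8x shorter than TF NeRF's.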

The code has been carefully tested and reviewed to match the
[original TF NeRF implementation](https://github.com/bmild/nerf).
If you run into problems using this code, please do not open an issue, since the repo
is shared by all projects under Google Research. Instead, email
jaxnerf@google.com.
## Installation
We recommend using [Anaconda](https://www.anaconda.com/products/individual) to set
up the environment. Run the following commands:
```
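# Download only the jaxnerf folder of the google-research repo.
# Note: GitHub ended Subversion support in January 2024, so if svn export
# fails, fetch jaxnerf/ with a sparse git checkout of the repo instead.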
svn export https://github.com/google-research/google-research/trunk/jaxnerf
# Create a conda environment. Use Python 3.6-3.8, since one of the
# dependencies (TensorFlow) does not support Python 3.9 yet.
conda create --name jaxnerf python=3.6.12; conda activate jaxnerf
# Prepare pip
conda install pip; pip install --upgrade pip
# Install requirements
pip install -r jaxnerf/requirements.txt
# [Optional] Install GPU (CUDA) support for JAX
# Remember to change cuda101 to your CUDA version, e.g. cuda110 for CUDA 11.0.
pip install --upgrade jax jaxlib==0.1.57+cuda101 -f https://storage.googleapis.com/jax-releases/jax_releases.html
```
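
Before launching a long training job, it can help to confirm that JAX actually sees your accelerator. The following is a minimal sketch (not part of JaxNeRF) using the public `jax.devices()` and `jax.device_count()` calls; if it reports only `cpu` after installing the CUDA build, the `jaxlib` wheel likely does not match your CUDA version.

```python
# Quick sanity check that JAX is installed and which accelerator it sees.
import importlib.util
from typing import Optional


def describe_jax_backend() -> Optional[dict]:
    """Summarize the local JAX install, or return None if JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily so this script runs even without JAX

    devices = jax.devices()
    return {
        "jax_version": jax.__version__,
        "platform": devices[0].platform,  # e.g. "cpu", "gpu", or "tpu"
        "device_count": jax.device_count(),
    }


if __name__ == "__main__":
    info = describe_jax_backend()
    if info is None:
        print("JAX is not installed; re-run the pip commands above.")
    else:
        print(info)
```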
Next, download the datasets
from the [NeRF official Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1):
get `nerf_synthetic.zip` and `nerf_llff_data.zip` and unzip them
wherever you like. The instructions below assume they are placed under `/tmp/jaxnerf/data/`.
That's it for installation. **Note:** run all of the following commands from the parent folder of `jaxnerf`, not from inside the `jaxnerf` folder itself.
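
To catch an incomplete download or a misplaced folder early, you could check the unzipped data with a small script like the one below. It is a hypothetical helper, not part of JaxNeRF, and it assumes the archives unzip to `<root>/nerf_synthetic/<scene>/` (with `transforms_*.json` files) and `<root>/nerf_llff_data/<scene>/` (with `poses_bounds.npy`); adjust `EXPECTED` if your layout differs.

```python
# Check that the unzipped NeRF datasets contain the expected per-scene files.
# ASSUMPTION: <root>/nerf_synthetic/<scene>/... and <root>/nerf_llff_data/<scene>/...
from pathlib import Path
from typing import Dict, List

EXPECTED: Dict[str, Dict[str, List[str]]] = {
    "nerf_synthetic": {
        "scenes": ["chair", "drums", "ficus", "hotdog",
                   "lego", "materials", "mic", "ship"],
        "files": ["transforms_train.json", "transforms_test.json"],
    },
    "nerf_llff_data": {
        "scenes": ["fern", "flower", "fortress", "horns",
                   "leaves", "orchids", "room", "trex"],
        "files": ["poses_bounds.npy"],
    },
}


def missing_files(data_root: str) -> List[str]:
    """Return every expected dataset file that is missing under data_root."""
    root = Path(data_root)
    missing = []
    for dataset, spec in EXPECTED.items():
        for scene in spec["scenes"]:
            for name in spec["files"]:
                path = root / dataset / scene / name
                if not path.is_file():
                    missing.append(str(path))
    return missing


if __name__ == "__main__":
    problems = missing_files("/tmp/jaxnerf/data")
    print("All scenes found." if not problems else "\n".join(problems))
```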
## Two Commands for Everything
```
bash jaxnerf/train.sh demo /tmp/jaxnerf/data
bash jaxnerf/eval.sh demo /tmp/jaxnerf/data
```
Once both jobs finish (which may take a while if you only have a single GPU
or a CPU), you'll have a folder, `/tmp/jaxnerf/data/demo`, containing:
* Trained NeRF models for all scenes in the blender dataset.
* Rendered images and depth maps for all test views.
* The collected PSNRs of all scenes in a TXT file.
Note that we used the `demo` config here, which is essentially the paper's `blender` config
with a smaller batch size and far fewer training steps. You can, of course, replace
`demo` with another config and `/tmp/jaxnerf/data` with a different data location.
We provide 2 configurations in the `configs` folder that match the original
configurations used in the paper for the blender and LLFF datasets.
Be careful when using them: their batch sizes are large, so you may get out-of-memory (OOM) errors on limited hardware, for example a single GPU with little memory. They also use many training steps, so training all scenes may take days.
## Play with One Scene
You can also train NeRF on a single scene. The easiest way is to use one of the provided configs:
```
python -m jaxnerf.train \
--data_dir=/PATH/TO/YOUR/SCENE/DATA \
--train_dir=/PATH/TO/THE/PLACE/YOU/WANT/TO/SAVE/CHECKPOINTS \
--config=configs/CONFIG_YOU_LIKE
```
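
If you want to train several scenes one after another without the shell scripts, a loop over the command above is enough. The helper below is hypothetical (not part of JaxNeRF) and reuses only the three flags shown above; the paths and the `configs/blender` value in the usage example are placeholders for your own.

```python
# Build (and optionally run) one `python -m jaxnerf.train` command per scene.
# HYPOTHETICAL helper: only --data_dir, --train_dir and --config are used.
import subprocess
import sys
from pathlib import Path
from typing import List, Sequence


def build_train_commands(data_root: str, train_root: str, config: str,
                         scenes: Sequence[str]) -> List[List[str]]:
    """Return one argv list per scene, each with its own checkpoint folder."""
    commands = []
    for scene in scenes:
        commands.append([
            sys.executable, "-m", "jaxnerf.train",
            f"--data_dir={Path(data_root) / scene}",
            f"--train_dir={Path(train_root) / scene}",
            f"--config={config}",
        ])
    return commands


def run_all(commands: List[List[str]]) -> None:
    """Run the commands sequentially, stopping at the first failure."""
    for argv in commands:
        subprocess.run(argv, check=True)  # launch from jaxnerf's parent folder


if __name__ == "__main__":
    cmds = build_train_commands("/tmp/jaxnerf/data/nerf_synthetic",
                                "/tmp/jaxnerf/checkpoints", "configs/blender",
                                ["lego", "ship"])
    for argv in cmds:
        print(" ".join(argv))
```

Printing the commands first, then calling `run_all(cmds)` once they look right, keeps a typo in a path from silently starting a multi-day job.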
Evaluating NeRF on one scene is similar:
```
python -m jaxnerf.eval \
--data_dir=/PATH/TO/YOUR/SCENE/DATA \
--train_dir=/PATH/TO/THE/PLACE/YOU/SAVED/CHECKPOINTS \
--config=configs/CONFIG_YOU_LIKE \
--chunk=4096
```
The `chunk` parameter defines how many rays are fed to the model at once.
We recommend using the largest value that fits in your device's memory;
smaller values also work, just more slowly.
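
Conceptually, chunking splits all rays of an image into slices of at most `chunk` rays, runs the model on each slice, and concatenates the results, so peak memory grows with `chunk` rather than with image size. The NumPy sketch below illustrates only that idea; `render_fn` is a hypothetical stand-in for the model, and JaxNeRF's own implementation may differ in details.

```python
# Sketch of chunked rendering: feed at most `chunk` rays to the model at a time.
# `render_fn` is a HYPOTHETICAL stand-in for the NeRF forward pass.
import math
from typing import Callable

import numpy as np


def render_in_chunks(rays: np.ndarray,
                     render_fn: Callable[[np.ndarray], np.ndarray],
                     chunk: int = 4096) -> np.ndarray:
    """Apply render_fn to consecutive slices of `rays` and stitch the results."""
    outputs = []
    for start in range(0, rays.shape[0], chunk):
        outputs.append(render_fn(rays[start:start + chunk]))
    return np.concatenate(outputs, axis=0)


def num_chunks(num_rays: int, chunk: int) -> int:
    """Number of model calls needed to render num_rays rays."""
    return math.ceil(num_rays / chunk)


def fake_model(rays: np.ndarray) -> np.ndarray:
    """Placeholder that returns one 3-channel "color" per ray."""
    return rays[:, :3] * 0.5


if __name__ == "__main__":
    rays = np.random.rand(800 * 800, 6).astype(np.float32)  # origin + direction
    rgb = render_in_chunks(rays, fake_model, chunk=4096)
    print(rgb.shape, num_chunks(rays.shape[0], 4096))  # (640000, 3) 157
```

With `--chunk=4096`, an 800x800 image therefore takes 157 model calls; halving `chunk` roughly halves the memory per call but doubles the number of calls.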
You can also define your own configurations by passing command line flags. Please refer to the `define_flags` function in `nerf/utils.py` for all the flags and their meanings.
**Note**: For the ficus scene in the blender dataset, we noticed that training is sensitive to initialization
(e.g., different random seeds) when using the original learning rate schedule from the paper.
Therefore, we provide a simple tweak for more stable training (turned off by default): `lr_delay_steps` and `lr_delay_mult`.
These make training start from a smaller learning rate (`lr_init` * `lr_delay_mult`) during the first `lr_delay_steps` steps.
We did not use them for our pretrained models,
but we tested `lr_delay_steps=5000` with `lr_delay_mult=0.2` and training ran smoothly.
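
The text above only states that the learning rate starts at `lr_init * lr_delay_mult` during the first `lr_delay_steps` steps. The sketch below shows one way such a delay can be combined with a decaying schedule; the sine-shaped ramp, the log-linear decay, and the `lr_final` and `max_steps` parameters are assumptions for illustration, so check `define_flags` and the training code for the actual behavior.

```python
# Sketch of a learning-rate schedule with an optional warm-up delay.
# ASSUMPTIONS: log-linear decay from lr_init to lr_final over max_steps, and a
# sine-shaped ramp of the multiplier from lr_delay_mult to 1 over lr_delay_steps.
import math


def learning_rate(step: int, lr_init: float, lr_final: float, max_steps: int,
                  lr_delay_steps: int = 0, lr_delay_mult: float = 1.0) -> float:
    if lr_delay_steps > 0:
        # Multiplier equals lr_delay_mult at step 0 and reaches 1 at lr_delay_steps.
        progress = min(max(step / lr_delay_steps, 0.0), 1.0)
        delay_rate = lr_delay_mult + (1.0 - lr_delay_mult) * math.sin(0.5 * math.pi * progress)
    else:
        delay_rate = 1.0
    # Interpolate linearly in log space between lr_init and lr_final.
    t = min(max(step / max_steps, 0.0), 1.0)
    log_lerp = math.exp(math.log(lr_init) * (1.0 - t) + math.log(lr_final) * t)
    return delay_rate * log_lerp


if __name__ == "__main__":
    # The delay settings are the ones tested above; the rest are placeholders.
    kwargs = dict(lr_init=5e-4, lr_final=5e-5, max_steps=1_000_000,
                  lr_delay_steps=5000, lr_delay_mult=0.2)
    for step in (0, 2500, 5000, 500_000, 1_000_000):
        print(step, learning_rate(step, **kwargs))
```

A smooth ramp avoids an abrupt jump in step size at `lr_delay_steps`, which is in the same spirit as the stability fix described above.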
## Pretrained Models
We provide a collection of pretrained NeRF models that match the numbers
reported in the [paper](https://arxiv.org/abs/2003.08934). Actually, ours are
slightly better overall because we trained for more iterations (while still
being much faster!). You can find our pretrained models
[here](http://storage.googleapis.com/gresearch/jaxnerf/jaxnerf_pretrained_models.zip).
The performance (PSNR, in dB) of our pretrained NeRF models is listed below:
### Blender
| Scene | Chair | Drums | Ficus | Hotdog | Lego | Materials | Mic | Ship | Mean |
|---------|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|
| TF NeRF | 33.00 | 25.01 | 30.13 | 36.18 | 32.54 | 29.62 | 32.91 | 28.65 | 31.01 |
| JaxNeRF | **34.08** | **25.03** | **30.43** | **36.92** | **33.28** | **29.91** | **34.53** | **29.36** | **31.69** |
### LLFF
| Scene | Room | Fern | Leaves | Fortress | Orchids | Flower | T-Rex | Horns | Mean |
|---------|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|
| TF NeRF | 32.70 | **25.17** | 20.92 | 31.16 | **20.36** | 27.40 | 26.80 | 27.45 | 26.50 |
| JaxNeRF | **33.04** | 24.83 | **21.23** | **31.76** | 20.27 | **28.07** | **27.42** | **28.10** | **26.84** |
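
For reference, PSNR for images with pixel values in [0, 1] is commonly computed as -10 * log10(MSE), and the "Mean" column above is consistent with an unweighted average of the eight per-scene values. The sketch below uses that standard definition; it is not taken from the JaxNeRF evaluation code.

```python
# PSNR as commonly defined for images with values in [0, 1]:
#   PSNR = -10 * log10(MSE)   (in dB; higher is better)
import math
from typing import Sequence


def psnr_from_mse(mse: float) -> float:
    return -10.0 * math.log10(mse)


def psnr(pred: Sequence[float], target: Sequence[float]) -> float:
    """PSNR between two flattened images with pixel values in [0, 1]."""
    if len(pred) != len(target) or not pred:
        raise ValueError("pred and target must be non-empty and equally sized")
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    return psnr_from_mse(mse)


def mean_psnr(per_scene: Sequence[float]) -> float:
    """Unweighted mean over scenes, matching the 'Mean' column above."""
    return sum(per_scene) / len(per_scene)


if __name__ == "__main__":
    jaxnerf_blender = [34.08, 25.03, 30.43, 36.92, 33.28, 29.91, 34.53, 29.36]
    print(round(mean_psnr(jaxnerf_blender), 2))  # 31.69
```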
## Citation
If you use this software package, please cite it as:
```
@software{jaxnerf2020github,
author = {Boyang Deng and Jonathan T. Barron and Pratul P. Srinivasan},
title = {{JaxNeRF}: an efficient {JAX} implementation of {NeRF}},
url = {https://github.com/google-research/google-research/tree/master/jaxnerf},
version = {0.0},
year = {2020},
}
```
and also cite the original NeRF paper:
```
@inproceedings{mildenhall2020nerf,
title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
year={2020},
booktitle={ECCV},
}
```
## Acknowledgement
We'd like to thank
[Daniel Duckworth](http://www.stronglyconvex.com/),
[Dan Gnanapragasam](https://research.google/people/DanGnanapragasam/),
and [James Bradbury](https://twitter.com/jekbradbury)
for their help on reviewing and optimizing this code.
We'd like to also thank the amazing [JAX](https://github.com/google/jax) team for
very insightful and helpful discussions on how to use JAX for NeRF.