
DPOT: Auto-Regressive Denoising Operator Transformer for Large-Scale PDE Pre-Training (ICML'2024)

Code for the paper DPOT: Auto-Regressive Denoising Operator Transformer for Large-Scale PDE Pre-Training (ICML'2024). It pre-trains neural operator transformers (from 7M to 1B parameters) on multiple PDE datasets. We will release the pre-trained weights soon.

fig1

Our pre-trained DPOT achieves state-of-the-art performance on multiple PDE datasets and can be fine-tuned on different types of downstream PDE problems.

fig2

Pre-trained Model Configuration

We have five pre-trained checkpoints of different sizes.

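| Size | Attention dim | MLP dim | Layers | Heads | Model size |
|------|---------------|---------|--------|-------|------------|
| Tiny | 512 | 512 | 4 | 4 | 7M |
| Small | 1024 | 1024 | 6 | 8 | 30M |
| Medium | 1024 | 4096 | 12 | 8 | 122M |
| Large | 1536 | 6144 | 24 | 16 | 509M |
| Huge | 2048 | 8092 | 27 | 8 | 1.03B |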
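
For scripting, the rows above can be kept in a small lookup. The following is only a sketch: the dictionary name, its keys, and the helper `largest_within` are illustrative and not part of the DPOT code base, and the parameter counts are the rounded figures from the table.

```python
# Values transcribed from the table above; all names here are illustrative.
DPOT_SIZES = {
    "Tiny":   {"attention_dim": 512,  "mlp_dim": 512,  "layers": 4,  "heads": 4,  "params": 7e6},
    "Small":  {"attention_dim": 1024, "mlp_dim": 1024, "layers": 6,  "heads": 8,  "params": 30e6},
    "Medium": {"attention_dim": 1024, "mlp_dim": 4096, "layers": 12, "heads": 8,  "params": 122e6},
    "Large":  {"attention_dim": 1536, "mlp_dim": 6144, "layers": 24, "heads": 16, "params": 509e6},
    "Huge":   {"attention_dim": 2048, "mlp_dim": 8092, "layers": 27, "heads": 8,  "params": 1.03e9},
}


def largest_within(budget_params):
    """Return the name of the largest listed size that fits the parameter budget, or None."""
    fitting = [
        (cfg["params"], name)
        for name, cfg in DPOT_SIZES.items()
        if cfg["params"] <= budget_params
    ]
    return max(fitting)[1] if fitting else None


print(largest_within(200e6))  # prints "Medium"
```

A helper like this can be useful when choosing a checkpoint for fine-tuning under a memory budget.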

Loading pre-trained model

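The following example loads a pre-trained model. The configuration shown matches the Tiny checkpoint (attention dim 512, 4 layers).

import torch
from models.dpot import DPOTNet  # assumed import path; adjust to where DPOTNet is defined
model = DPOTNet(img_size=128, patch_size=8, mixing_type='afno', in_channels=4, in_timesteps=10, out_timesteps=1, out_channels=4, normalize=False, embed_dim=512, modes=32, depth=4, n_blocks=4, mlp_ratio=1, out_layer_dim=32, n_cls=12)
model.load_state_dict(torch.load('model_Ti.pth', map_location='cpu')['model'])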
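
The configuration above takes `in_timesteps=10` input frames and predicts `out_timesteps=1` frame, so longer forecasts have to be produced auto-regressively by feeding predictions back in. The sketch below illustrates that sliding-window loop with NumPy only. It makes several assumptions: `rollout` and `step_fn` are hypothetical names, the time-first array layout is chosen for illustration, and the linear extrapolator is a stand-in for the network, whose actual input layout and return value are not specified here.

```python
import numpy as np


def rollout(step_fn, history, n_steps, in_timesteps=10):
    """Auto-regressively predict `n_steps` new frames.

    history: array of shape (T, ...) with T >= in_timesteps (time-first layout is an assumption).
    step_fn: hypothetical callable mapping a window of shape (in_timesteps, ...)
             to a single next frame of shape (...).
    """
    frames = list(history)
    if len(frames) < in_timesteps:
        raise ValueError("history must contain at least in_timesteps frames")
    for _ in range(n_steps):
        window = np.stack(frames[-in_timesteps:])  # most recent in_timesteps frames
        frames.append(step_fn(window))             # the prediction becomes part of the next input
    return np.stack(frames[len(history):])         # only the newly predicted frames


# Stand-in predictor: linear extrapolation from the last two frames.
def linear_step(window):
    return 2 * window[-1] - window[-2]


# Ten frames on a 4x4 grid whose value equals the time index.
history = np.stack([np.full((4, 4), float(t)) for t in range(10)])
preds = rollout(linear_step, history, n_steps=3)
print(preds.shape)     # (3, 4, 4)
print(preds[:, 0, 0])  # [10. 11. 12.]
```

To use a trained network in this loop, `step_fn` would wrap the model call and convert between arrays and tensors in whatever layout the model expects.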

Datasets

All datasets are stored in HDF5 format, with the samples in a field named data. Some datasets are split across individual HDF5 files; others are stored in a single HDF5 file.

Download the original files from the sources below and preprocess them into the /data folder.

| Dataset | Link |
|---------|------|
| FNO data | Here |
| PDEBench data | Here |
| PDEArena data | Here |
| CFDbench data | Here |
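
After preprocessing, it can help to confirm that a file really contains the expected data field. The sketch below uses h5py to list every dataset in a file with its shape and dtype. The function name `describe_hdf5` is hypothetical, and nothing is assumed about array shapes, which differ between datasets.

```python
import h5py


def describe_hdf5(path):
    """Return {dataset_name: (shape, dtype_string)} for every dataset in the file."""
    summary = {}

    def visit(name, obj):
        # visititems walks groups and datasets; only datasets are recorded.
        if isinstance(obj, h5py.Dataset):
            summary[name] = (obj.shape, str(obj.dtype))

    with h5py.File(path, "r") as f:
        f.visititems(visit)
    return summary
```

Calling `describe_hdf5` on a preprocessed file should return a dictionary with a `data` entry; a missing entry suggests the preprocessing step did not complete.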

Citation

If you use DPOT in your research, please use the following BibTeX entry.

@article{hao2024dpot,
  title={DPOT: Auto-Regressive Denoising Operator Transformer for Large-Scale PDE Pre-Training},
  author={Hao, Zhongkai and Su, Chang and Liu, Songming and Berner, Julius and Ying, Chengyang and Su, Hang and Anandkumar, Anima and Song, Jian and Zhu, Jun},
  journal={arXiv preprint arXiv:2403.03542},
  year={2024}
}