diff --git a/.huggingface/.gitignore b/.huggingface/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..f59ec20aabf5842d237244ece8c81ab184faeac1
--- /dev/null
+++ b/.huggingface/.gitignore
@@ -0,0 +1 @@
+*
\ No newline at end of file

[~40 machine-generated Hugging Face download-cache hunks condensed here for readability. For each fetched repo file (`.gitattributes`, `README.md`, `app.py`, `requirements.txt`, the `examples/*.webp` images, the `tsr/**` sources and their compiled `__pycache__/*.pyc` files, and `wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl`), this commit adds an empty `.huggingface/download/<path>.lock` file (blob e69de29b) and a three-line `.huggingface/download/<path>.metadata` file holding the source commit hash 97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e, the file's blob/etag hash, and a download timestamp in the range 1723652257-1723652259.]

diff --git a/examples/captured_p.webp b/examples/captured_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..a0b83623bd3e8528385869ca5370eb5fc886b6f5
Binary files /dev/null and b/examples/captured_p.webp differ
diff --git a/examples/chair_p.webp b/examples/chair_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..a7812c95b549d96ed0495ca85af88441d3f9d457
Binary files /dev/null and b/examples/chair_p.webp differ
diff --git a/examples/flamingo_p.webp b/examples/flamingo_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..5b0164267efdcd7aa67dc4993ba0f28f05b75d62
Binary files /dev/null and b/examples/flamingo_p.webp differ
diff --git a/examples/hamburger_p.webp b/examples/hamburger_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..7e5b467a5bf2c52c67e34ae55457cd7940884fe9
Binary files /dev/null and b/examples/hamburger_p.webp differ
diff --git a/examples/horse_p.webp b/examples/horse_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..c3bd936b078a421f12321f12fc72a45936dac1f3
Binary files /dev/null and b/examples/horse_p.webp differ
diff --git a/examples/iso_house.webp b/examples/iso_house.webp
new file mode 100644
index 0000000000000000000000000000000000000000..1d8d71b4417299de911fedfe7a89118eefd7103a
Binary files /dev/null and b/examples/iso_house.webp differ
diff --git a/examples/marble_p.webp b/examples/marble_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..548a0862bbedf4fca1bd243021703481e81443b5
Binary files /dev/null and b/examples/marble_p.webp differ
diff --git a/examples/police_woman_p.webp b/examples/police_woman_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..a860e1cfe4e28a9cae59239e851cead16193e4b8
Binary files /dev/null and b/examples/police_woman_p.webp differ
diff --git a/examples/poly_fox.webp b/examples/poly_fox.webp
new file mode 100644
index 0000000000000000000000000000000000000000..e749c683447d9d51dd17671a87c406ac8abba6ba
Binary files /dev/null and b/examples/poly_fox.webp differ
diff --git a/examples/robot_p.webp b/examples/robot_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..ffe480ebfe0216b8701aa35e29152ac778ca7c57
Binary files /dev/null and b/examples/robot_p.webp differ
diff --git a/examples/teapot.webp b/examples/teapot.webp
new file mode 100644
index 0000000000000000000000000000000000000000..b6f13c13498eceba7583a95ef7b7a087f84fc73d
Binary files /dev/null and b/examples/teapot.webp differ
diff --git a/examples/tiger_girl.webp b/examples/tiger_girl.webp
new file mode 100644
index 0000000000000000000000000000000000000000..489d048b5af2d805cb1ae41eef18a2d8abcb35aa
Binary files /dev/null and b/examples/tiger_girl.webp differ
diff --git a/examples/unicorn_p.webp b/examples/unicorn_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..333d17406c01cfd93b0d7790926db6519bd2df4c
Binary files /dev/null and b/examples/unicorn_p.webp differ
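The `.lock`/`.metadata` pairs summarized above match the bookkeeping `huggingface_hub` writes next to each file when a repo is downloaded into a local directory. A minimal sketch of how such a tree is typically produced — the `repo_id` is illustrative and not confirmed by the diff, and the exact metadata layout is an internal detail of the library version in use:

```python
# Sketch: materializing a Space into the working directory.
# snapshot_download is a real huggingface_hub API; the per-file
# .lock/.metadata files (commit hash, blob/etag hash, timestamp) are
# internal cache bookkeeping it writes alongside the downloaded files.
from huggingface_hub import snapshot_download

path = snapshot_download(
    repo_id="stabilityai/TripoSR",  # hypothetical source repo for illustration
    local_dir=".",                  # files land here; metadata under .huggingface/
)
print(path)
```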
diff --git a/tsr/__pycache__/system.cpython-310.pyc b/tsr/__pycache__/system.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e959f300ee2347d00c603bd99f9bd867dadf4499
Binary files /dev/null and b/tsr/__pycache__/system.cpython-310.pyc differ
diff --git a/tsr/__pycache__/utils.cpython-310.pyc b/tsr/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4fcac64634b0d8e7f5f48ece0f9bd046ca3bedbb
Binary files /dev/null and b/tsr/__pycache__/utils.cpython-310.pyc differ
diff --git a/tsr/models/__pycache__/camera.cpython-310.pyc b/tsr/models/__pycache__/camera.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf99205bb6f64375040f3c92d02df1f48695385c
Binary files /dev/null and b/tsr/models/__pycache__/camera.cpython-310.pyc differ
diff --git a/tsr/models/__pycache__/isosurface.cpython-310.pyc b/tsr/models/__pycache__/isosurface.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5858895751ab820d26e557b5e9dfebc41679cae1
Binary files /dev/null and b/tsr/models/__pycache__/isosurface.cpython-310.pyc differ
diff --git a/tsr/models/__pycache__/nerf_renderer.cpython-310.pyc b/tsr/models/__pycache__/nerf_renderer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f165c9fdb535e8888a8ab3393c6f6b8d7deb9065
Binary files /dev/null and b/tsr/models/__pycache__/nerf_renderer.cpython-310.pyc differ
diff --git a/tsr/models/__pycache__/network_utils.cpython-310.pyc b/tsr/models/__pycache__/network_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9681517e0c22f29e8b5bc02aaf2a29f06087e554
Binary files /dev/null and b/tsr/models/__pycache__/network_utils.cpython-310.pyc differ
diff --git a/tsr/models/isosurface.py b/tsr/models/isosurface.py
new file mode 100644
index 0000000000000000000000000000000000000000..39c8389b5562f5ebb787ef85bcfff56d85aa51db
--- /dev/null
+++ b/tsr/models/isosurface.py
@@ -0,0 +1,48 @@
+from typing import Callable, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torchmcubes import marching_cubes
+
+
+class IsosurfaceHelper(nn.Module):
+    points_range: Tuple[float, float] = (0, 1)
+
+    @property
+    def grid_vertices(self) -> torch.FloatTensor:
+        raise NotImplementedError
+
+
+class MarchingCubeHelper(IsosurfaceHelper):
+    def __init__(self, resolution: int) -> None:
+        super().__init__()
+        self.resolution = resolution
+        self.mc_func: Callable = marching_cubes
+        self._grid_vertices: Optional[torch.FloatTensor] = None
+
+    @property
+    def grid_vertices(self) -> torch.FloatTensor:
+        if self._grid_vertices is None:
+            # keep the vertices on CPU so that we can support very large resolution
+            x, y, z = (
+                torch.linspace(*self.points_range, self.resolution),
+                torch.linspace(*self.points_range, self.resolution),
+                torch.linspace(*self.points_range, self.resolution),
+            )
+            x, y, z = torch.meshgrid(x, y, z, indexing="ij")
+            verts = torch.cat(
+                [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
+            ).reshape(-1, 3)
+            self._grid_vertices = verts
+        return self._grid_vertices
+
+    def forward(
+        self,
+        level: torch.FloatTensor,
+    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+        level = -level.view(self.resolution, self.resolution, self.resolution)
+        v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
+        v_pos = v_pos[..., [2, 1, 0]]
+        v_pos = v_pos / (self.resolution - 1.0)
+        return v_pos, t_pos_idx
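A quick way to exercise `MarchingCubeHelper` in isolation — a sketch that assumes `torchmcubes` is installed and substitutes a synthetic sphere density for the NeRF density the real pipeline feeds in:

```python
# Sketch: extract a mesh from a synthetic density field with MarchingCubeHelper.
# The sphere density (positive inside a radius-0.35 ball) is a made-up stand-in.
import torch
from tsr.models.isosurface import MarchingCubeHelper

helper = MarchingCubeHelper(resolution=64)
pts = helper.grid_vertices                  # (64**3, 3) grid points in [0, 1]
density = 0.35 - (pts - 0.5).norm(dim=-1)   # level set crosses zero at the sphere
v_pos, t_pos_idx = helper(density)          # vertices rescaled to [0, 1], faces
print(v_pos.shape, t_pos_idx.shape)
```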
diff --git a/tsr/models/nerf_renderer.py b/tsr/models/nerf_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2d3e652c423e62c4e54fdf7c0751602f51b107b
--- /dev/null
+++ b/tsr/models/nerf_renderer.py
@@ -0,0 +1,180 @@
+from dataclasses import dataclass, field
+from typing import Dict
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, reduce
+
+from ..utils import (
+    BaseModule,
+    chunk_batch,
+    get_activation,
+    rays_intersect_bbox,
+    scale_tensor,
+)
+
+
+class TriplaneNeRFRenderer(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        radius: float
+
+        feature_reduction: str = "concat"
+        density_activation: str = "trunc_exp"
+        density_bias: float = -1.0
+        color_activation: str = "sigmoid"
+        num_samples_per_ray: int = 128
+        randomized: bool = False
+
+    cfg: Config
+
+    def configure(self) -> None:
+        assert self.cfg.feature_reduction in ["concat", "mean"]
+        self.chunk_size = 0
+
+    def set_chunk_size(self, chunk_size: int):
+        assert (
+            chunk_size >= 0
+        ), "chunk_size must be a non-negative integer (0 for no chunking)."
+        self.chunk_size = chunk_size
+
+    def query_triplane(
+        self,
+        decoder: torch.nn.Module,
+        positions: torch.Tensor,
+        triplane: torch.Tensor,
+    ) -> Dict[str, torch.Tensor]:
+        input_shape = positions.shape[:-1]
+        positions = positions.view(-1, 3)
+
+        # positions in (-radius, radius)
+        # normalized to (-1, 1) for grid sample
+        positions = scale_tensor(
+            positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
+        )
+
+        def _query_chunk(x):
+            indices2D: torch.Tensor = torch.stack(
+                (x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
+                dim=-3,
+            )
+            out: torch.Tensor = F.grid_sample(
+                rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
+                rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
+                align_corners=False,
+                mode="bilinear",
+            )
+            if self.cfg.feature_reduction == "concat":
+                out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
+            elif self.cfg.feature_reduction == "mean":
+                out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
+            else:
+                raise NotImplementedError
+
+            net_out: Dict[str, torch.Tensor] = decoder(out)
+            return net_out
+
+        if self.chunk_size > 0:
+            net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
+        else:
+            net_out = _query_chunk(positions)
+
+        net_out["density_act"] = get_activation(self.cfg.density_activation)(
+            net_out["density"] + self.cfg.density_bias
+        )
+        net_out["color"] = get_activation(self.cfg.color_activation)(
+            net_out["features"]
+        )
+
+        net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()}
+
+        return net_out
+
+    def _forward(
+        self,
+        decoder: torch.nn.Module,
+        triplane: torch.Tensor,
+        rays_o: torch.Tensor,
+        rays_d: torch.Tensor,
+        **kwargs,
+    ):
+        rays_shape = rays_o.shape[:-1]
+        rays_o = rays_o.view(-1, 3)
+        rays_d = rays_d.view(-1, 3)
+        n_rays = rays_o.shape[0]
+
+        t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)
+        t_near, t_far = t_near[rays_valid], t_far[rays_valid]
+
+        t_vals = torch.linspace(
+            0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
+        )
+        t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
+        z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None]  # (N_rays, N_samples)
+
+        xyz = (
+            rays_o[:, None, :] + z_vals[..., None] * rays_d[..., None, :]
+        )  # (N_rays, N_sample, 3)
+
+        mlp_out = self.query_triplane(
+            decoder=decoder,
+            positions=xyz,
+            triplane=triplane,
+        )
+
+        eps = 1e-10
+        # deltas = z_vals[:, 1:] - z_vals[:, :-1]  # (N_rays, N_samples)
+        deltas = t_vals[1:] - t_vals[:-1]  # (N_rays, N_samples)
+        alpha = 1 - torch.exp(
+            -deltas * mlp_out["density_act"][..., 0]
+        )  # (N_rays, N_samples)
+        accum_prod = torch.cat(
+            [
+                torch.ones_like(alpha[:, :1]),
+                torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
+            ],
+            dim=-1,
+        )
+        weights = alpha * accum_prod  # (N_rays, N_samples)
+        comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2)  # (N_rays, 3)
+        opacity_ = weights.sum(dim=-1)  # (N_rays)
+
+        comp_rgb = torch.zeros(
+            n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
+        )
+        opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
+        comp_rgb[rays_valid] = comp_rgb_
+        opacity[rays_valid] = opacity_
+
+        comp_rgb += 1 - opacity[..., None]
+        comp_rgb = comp_rgb.view(*rays_shape, 3)
+
+        return comp_rgb
+
+    def forward(
+        self,
+        decoder: torch.nn.Module,
+        triplane: torch.Tensor,
+        rays_o: torch.Tensor,
+        rays_d: torch.Tensor,
+    ) -> Dict[str, torch.Tensor]:
+        if triplane.ndim == 4:
+            comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
+        else:
+            comp_rgb = torch.stack(
+                [
+                    self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
+                    for i in range(triplane.shape[0])
+                ],
+                dim=0,
+            )
+
+        return comp_rgb
+
+    def train(self, mode=True):
+        self.randomized = mode and self.cfg.randomized
+        return super().train(mode=mode)
+
+    def eval(self):
+        self.randomized = False
+        return super().eval()
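The compositing step in `_forward` above is standard NeRF alpha compositing: alpha_i = 1 - exp(-delta_i * sigma_i), transmittance T_i = prod over j < i of (1 - alpha_j), weight_i = T_i * alpha_i, then the per-sample colors are weight-summed onto a white background. A minimal standalone sketch of the same math on toy tensors (hypothetical sizes, not the module's API):

```python
# Sketch of the alpha compositing used by TriplaneNeRFRenderer._forward.
import torch

n_rays, n_samples = 4, 8
sigma = torch.rand(n_rays, n_samples)            # per-sample densities
color = torch.rand(n_rays, n_samples, 3)         # per-sample RGB
delta = torch.full((n_samples,), 1.0 / n_samples)  # uniform segment lengths

alpha = 1 - torch.exp(-delta * sigma)
trans = torch.cat(
    [torch.ones_like(alpha[:, :1]), torch.cumprod(1 - alpha[:, :-1] + 1e-10, dim=-1)],
    dim=-1,
)
weights = alpha * trans
rgb = (weights[..., None] * color).sum(dim=-2)
rgb = rgb + (1 - weights.sum(dim=-1, keepdim=True))  # white background, as above
print(rgb.shape)  # torch.Size([4, 3])
```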
diff --git a/tsr/models/network_utils.py b/tsr/models/network_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3844f533bf3b6c9afce6de3857255ee08125b1ba
--- /dev/null
+++ b/tsr/models/network_utils.py
@@ -0,0 +1,124 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from ..utils import BaseModule
+
+
+class TriplaneUpsampleNetwork(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        in_channels: int
+        out_channels: int
+
+    cfg: Config
+
+    def configure(self) -> None:
+        self.upsample = nn.ConvTranspose2d(
+            self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
+        )
+
+    def forward(self, triplanes: torch.Tensor) -> torch.Tensor:
+        triplanes_up = rearrange(
+            self.upsample(
+                rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
+            ),
+            "(B Np) Co Hp Wp -> B Np Co Hp Wp",
+            Np=3,
+        )
+        return triplanes_up
+
+
+class NeRFMLP(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        in_channels: int
+        n_neurons: int
+        n_hidden_layers: int
+        activation: str = "relu"
+        bias: bool = True
+        weight_init: Optional[str] = "kaiming_uniform"
+        bias_init: Optional[str] = None
+
+    cfg: Config
+
+    def configure(self) -> None:
+        layers = [
+            self.make_linear(
+                self.cfg.in_channels,
+                self.cfg.n_neurons,
+                bias=self.cfg.bias,
+                weight_init=self.cfg.weight_init,
+                bias_init=self.cfg.bias_init,
+            ),
+            self.make_activation(self.cfg.activation),
+        ]
+        for i in range(self.cfg.n_hidden_layers - 1):
+            layers += [
+                self.make_linear(
+                    self.cfg.n_neurons,
+                    self.cfg.n_neurons,
+                    bias=self.cfg.bias,
+                    weight_init=self.cfg.weight_init,
+                    bias_init=self.cfg.bias_init,
+                ),
+                self.make_activation(self.cfg.activation),
+            ]
+        layers += [
+            self.make_linear(
+                self.cfg.n_neurons,
+                4,  # density 1 + features 3
+                bias=self.cfg.bias,
+                weight_init=self.cfg.weight_init,
+                bias_init=self.cfg.bias_init,
+            )
+        ]
+        self.layers = nn.Sequential(*layers)
+
+    def make_linear(
+        self,
+        dim_in,
+        dim_out,
+        bias=True,
+        weight_init=None,
+        bias_init=None,
+    ):
+        layer = nn.Linear(dim_in, dim_out, bias=bias)
+
+        if weight_init is None:
+            pass
+        elif weight_init == "kaiming_uniform":
+            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
+        else:
+            raise NotImplementedError
+
+        if bias:
+            if bias_init is None:
+                pass
+            elif bias_init == "zero":
+                torch.nn.init.zeros_(layer.bias)
+            else:
+                raise NotImplementedError
+
+        return layer
+
+    def make_activation(self, activation):
+        if activation == "relu":
+            return nn.ReLU(inplace=True)
+        elif activation == "silu":
+            return nn.SiLU(inplace=True)
+        else:
+            raise NotImplementedError
+
+    def forward(self, x):
+        inp_shape = x.shape[:-1]
+        x = x.reshape(-1, x.shape[-1])
+
+        features = self.layers(x)
+        features = features.reshape(*inp_shape, -1)
+        out = {"density": features[..., 0:1], "features": features[..., 1:4]}
+
+        return out
diff --git a/tsr/models/tokenizers/__pycache__/dino.cpython-310.pyc b/tsr/models/tokenizers/__pycache__/dino.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a492f56dc156938d3250d77f5c182ab05e1655ea
Binary files /dev/null and b/tsr/models/tokenizers/__pycache__/dino.cpython-310.pyc differ
diff --git a/tsr/models/tokenizers/__pycache__/image.cpython-310.pyc b/tsr/models/tokenizers/__pycache__/image.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06d33036639d982296542f5d441f991d3e01ddb0
Binary files /dev/null and b/tsr/models/tokenizers/__pycache__/image.cpython-310.pyc differ
diff --git a/tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc b/tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..31191fd86c3d1668aefd27059738590dc066cf7b
Binary files /dev/null and b/tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc differ
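To see the shape contract of `TriplaneUpsampleNetwork` without constructing the full `BaseModule` config machinery, here is an equivalent sketch in plain torch/einops (toy sizes; the real module wires the same ops through its `Config`):

```python
# Sketch: doubling the spatial resolution of a batch of triplanes.
import torch
import torch.nn as nn
from einops import rearrange

B, Np, Ci, Co, Hp = 2, 3, 40, 40, 32
upsample = nn.ConvTranspose2d(Ci, Co, kernel_size=2, stride=2)

triplanes = torch.randn(B, Np, Ci, Hp, Hp)
x = rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp")
x = upsample(x)  # kernel_size=2, stride=2 -> spatial size doubles: 32 -> 64
out = rearrange(x, "(B Np) Co Hp Wp -> B Np Co Hp Wp", Np=Np)
print(out.shape)  # torch.Size([2, 3, 40, 64, 64])
```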
Nt", B=batch_size + ) + if packed: + local_features = local_features.squeeze(1) + + return local_features + + def detokenize(self, *args, **kwargs): + raise NotImplementedError diff --git a/tsr/models/tokenizers/triplane.py b/tsr/models/tokenizers/triplane.py new file mode 100644 index 0000000000000000000000000000000000000000..ecdd7fd2201c974bb70b18a90a633287b814886f --- /dev/null +++ b/tsr/models/tokenizers/triplane.py @@ -0,0 +1,45 @@ +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +from einops import rearrange, repeat + +from ...utils import BaseModule + + +class Triplane1DTokenizer(BaseModule): + @dataclass + class Config(BaseModule.Config): + plane_size: int + num_channels: int + + cfg: Config + + def configure(self) -> None: + self.embeddings = nn.Parameter( + torch.randn( + (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size), + dtype=torch.float32, + ) + * 1 + / math.sqrt(self.cfg.num_channels) + ) + + def forward(self, batch_size: int) -> torch.Tensor: + return rearrange( + repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size), + "B Np Ct Hp Wp -> B Ct (Np Hp Wp)", + ) + + def detokenize(self, tokens: torch.Tensor) -> torch.Tensor: + batch_size, Ct, Nt = tokens.shape + assert Nt == self.cfg.plane_size**2 * 3 + assert Ct == self.cfg.num_channels + return rearrange( + tokens, + "B Ct (Np Hp Wp) -> B Np Ct Hp Wp", + Np=3, + Hp=self.cfg.plane_size, + Wp=self.cfg.plane_size, + ) diff --git a/tsr/models/transformer/__pycache__/attention.cpython-310.pyc b/tsr/models/transformer/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02bf1262175cdbe2bb9651042c247f53c2ab0e91 Binary files /dev/null and b/tsr/models/transformer/__pycache__/attention.cpython-310.pyc differ diff --git a/tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc b/tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c2f3d2a8c487f237c94433bd8aaafd26afb8ce0 Binary files /dev/null and b/tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc differ diff --git a/tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc b/tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39cb242f767c41c73138dda36700ef1554bbf31f Binary files /dev/null and b/tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc differ diff --git a/tsr/models/transformer/attention.py b/tsr/models/transformer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..eb873231e3ad195a7fef3a2c4ef217be3056cd4e --- /dev/null +++ b/tsr/models/transformer/attention.py @@ -0,0 +1,628 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tsr/models/transformer/attention.py b/tsr/models/transformer/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb873231e3ad195a7fef3a2c4ef217be3056cd4e
--- /dev/null
+++ b/tsr/models/transformer/attention.py
@@ -0,0 +1,628 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Attention(nn.Module):
+    r"""
+    A cross attention layer.
+
+    Parameters:
+        query_dim (`int`):
+            The number of channels in the query.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+        heads (`int`, *optional*, defaults to 8):
+            The number of heads to use for multi-head attention.
+        dim_head (`int`, *optional*, defaults to 64):
+            The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        bias (`bool`, *optional*, defaults to False):
+            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+        upcast_attention (`bool`, *optional*, defaults to False):
+            Set to `True` to upcast the attention computation to `float32`.
+        upcast_softmax (`bool`, *optional*, defaults to False):
+            Set to `True` to upcast the softmax computation to `float32`.
+        cross_attention_norm (`str`, *optional*, defaults to `None`):
+            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+            The number of groups to use for the group norm in the cross attention.
+        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the added key and value projections. If `None`, no projection is used.
+        norm_num_groups (`int`, *optional*, defaults to `None`):
+            The number of groups to use for the group norm in the attention.
+        spatial_norm_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the spatial normalization.
+        out_bias (`bool`, *optional*, defaults to `True`):
+            Set to `True` to use a bias in the output linear layer.
+        scale_qk (`bool`, *optional*, defaults to `True`):
+            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+        only_cross_attention (`bool`, *optional*, defaults to `False`):
+            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+            `added_kv_proj_dim` is not `None`.
+        eps (`float`, *optional*, defaults to 1e-5):
+            An additional value added to the denominator in group normalization that is used for numerical stability.
+        rescale_output_factor (`float`, *optional*, defaults to 1.0):
+            A factor to rescale the output by dividing it with this value.
+        residual_connection (`bool`, *optional*, defaults to `False`):
+            Set to `True` to add the residual connection to the output.
+        _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+            Set to `True` if the attention block is loaded from a deprecated state dict.
+        processor (`AttnProcessor`, *optional*, defaults to `None`):
+            The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
+            `AttnProcessor` otherwise.
+    """
+
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        upcast_attention: bool = False,
+        upcast_softmax: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,
+        added_kv_proj_dim: Optional[int] = None,
+        norm_num_groups: Optional[int] = None,
+        out_bias: bool = True,
+        scale_qk: bool = True,
+        only_cross_attention: bool = False,
+        eps: float = 1e-5,
+        rescale_output_factor: float = 1.0,
+        residual_connection: bool = False,
+        _from_deprecated_attn_block: bool = False,
+        processor: Optional["AttnProcessor"] = None,
+        out_dim: int = None,
+    ):
+        super().__init__()
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
+        self.cross_attention_dim = (
+            cross_attention_dim if cross_attention_dim is not None else query_dim
+        )
+        self.upcast_attention = upcast_attention
+        self.upcast_softmax = upcast_softmax
+        self.rescale_output_factor = rescale_output_factor
+        self.residual_connection = residual_connection
+        self.dropout = dropout
+        self.fused_projections = False
+        self.out_dim = out_dim if out_dim is not None else query_dim
+
+        # we make use of this private variable to know whether this class is loaded
+        # with an deprecated state dict so that we can convert it on the fly
+        self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        # for slice_size > 0 the attention score computation
+        # is split across the batch axis to save memory
+        # You can set slice_size with `set_attention_slice`
+        self.sliceable_head_dim = heads
+
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.only_cross_attention = only_cross_attention
+
+        if self.added_kv_proj_dim is None and self.only_cross_attention:
+            raise ValueError(
+                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+            )
+
+        if norm_num_groups is not None:
+            self.group_norm = nn.GroupNorm(
+                num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
+            )
+        else:
+            self.group_norm = None
+
+        self.spatial_norm = None
+
+        if cross_attention_norm is None:
+            self.norm_cross = None
+        elif cross_attention_norm == "layer_norm":
+            self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = self.cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels,
+                num_groups=cross_attention_norm_num_groups,
+                eps=1e-5,
+                affine=True,
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+            )
+
+        linear_cls = nn.Linear
+
+        self.linear_cls = linear_cls
+        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
+
+        if not self.only_cross_attention:
+            # only relevant for the `AddedKVProcessor` classes
+            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+        else:
+            self.to_k = None
+            self.to_v = None
+
+        if self.added_kv_proj_dim is not None:
+            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Dropout(dropout))
+
+        # set attention processor
+        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+        if processor is None:
+            processor = (
+                AttnProcessor2_0()
+                if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+                else AttnProcessor()
+            )
+        self.set_processor(processor)
+    def set_processor(self, processor: "AttnProcessor") -> None:
+        self.processor = processor
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        **cross_attention_kwargs,
+    ) -> torch.Tensor:
+        r"""
+        The forward method of the `Attention` class.
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                The hidden states of the query.
+            encoder_hidden_states (`torch.Tensor`, *optional*):
+                The hidden states of the encoder.
+            attention_mask (`torch.Tensor`, *optional*):
+                The attention mask to use. If `None`, no mask is applied.
+            **cross_attention_kwargs:
+                Additional keyword arguments to pass along to the cross attention.
+
+        Returns:
+            `torch.Tensor`: The output of the attention layer.
+        """
+        # The `Attention` class can call different attention processors / attention functions
+        # here we simply pass along all tensors to the selected processor class
+        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **cross_attention_kwargs,
+        )
+
+    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+        r"""
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
+        is the number of heads initialized while constructing the `Attention` class.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
+        """
+        head_size = self.heads
+        batch_size, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(
+            batch_size // head_size, seq_len, dim * head_size
+        )
+        return tensor
+
+    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+        r"""
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
+        the number of heads initialized while constructing the `Attention` class.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
+                reshaped to `[batch_size * heads, seq_len, dim // heads]`.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
+        """
+        head_size = self.heads
+        batch_size, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if out_dim == 3:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
+        return tensor
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + + return tensor + + def get_attention_scores( + self, + query: torch.Tensor, + key: torch.Tensor, + attention_mask: torch.Tensor = None, + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.device, + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, + attention_mask: torch.Tensor, + target_length: int, + batch_size: int, + out_dim: int = 3, + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. 
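+                # Editorial note (an assumption about intent): in the additive-bias
+                # convention used by Transformer1D (0 = keep, -10000 = discard),
+                # padding with zeros marks the extra key positions as "keep"
+                # rather than masking them out.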
+ padding_shape = ( + attention_mask.shape[0], + attention_mask.shape[1], + target_length, + ) + padding = torch.zeros( + padding_shape, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states( + self, encoder_hidden_states: torch.Tensor + ) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert ( + self.norm_cross is not None + ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + @torch.no_grad() + def fuse_projections(self, fuse=True): + is_cross_attention = self.cross_attention_dim != self.query_dim + device = self.to_q.weight.data.device + dtype = self.to_q.weight.data.dtype + + if not is_cross_attention: + # fetch weight matrices. + concatenated_weights = torch.cat( + [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + # create a new single projection layer and copy over the weights. + self.to_qkv = self.linear_cls( + in_features, out_features, bias=False, device=device, dtype=dtype + ) + self.to_qkv.weight.copy_(concatenated_weights) + + else: + concatenated_weights = torch.cat( + [self.to_k.weight.data, self.to_v.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_kv = self.linear_cls( + in_features, out_features, bias=False, device=device, dtype=dtype + ) + self.to_kv.weight.copy_(concatenated_weights) + + self.fused_projections = fuse + + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. 
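+ + Example (editorial sketch, not part of the original class; the shapes are + illustrative assumptions): + + attn = Attention(query_dim=512, heads=8, dim_head=64) + attn.set_processor(AttnProcessor()) + out = attn(hidden_states) # (2, 4096, 512) -> (2, 4096, 512)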
+ """ + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/tsr/models/transformer/basic_transformer_block.py b/tsr/models/transformer/basic_transformer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..a55ea6189c9e1da2f86f6d0957ae30f26fefff0a --- /dev/null +++ b/tsr/models/transformer/basic_transformer_block.py @@ -0,0 +1,314 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
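+
+# Editorial usage sketch for the block defined below (the shapes are
+# illustrative assumptions, not fixed by this module):
+#
+#     block = BasicTransformerBlock(dim=512, num_attention_heads=8,
+#                                   attention_head_dim=64, cross_attention_dim=768)
+#     out = block(hidden_states, encoder_hidden_states=cond)
+#     # hidden_states: (B, N, 512), cond: (B, M, 768) -> out: (B, N, 512)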
+ +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from .attention import Attention + + +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + attention_bias (`bool`, *optional*, defaults to `False`): + Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Only `"layer_norm"` is supported by this stripped-down block. + final_dropout (`bool`, *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + assert norm_type == "layer_norm" + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # The second attention block always uses a plain LayerNorm here; the + # AdaLayerNorm variants from the original diffusers block are not used + # in this stripped-down version. 
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim + if not double_self_attention + else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is None + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states + if self.only_cross_attention + else None, + attention_mask=attention_mask, + ) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `set_chunk_feed_forward`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [ + self.ff(hid_slice) + for hid_slice in norm_hidden_states.chunk( + num_chunks, dim=self._chunk_dim + ) + ], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool`, *optional*, defaults to False): Apply a final dropout. 
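+ + Worked example (editorial addition): with the defaults `mult=4` and + `activation_fn="geglu"`, a `dim=512` input is projected by GEGLU to a + 4096-wide tensor, split into a 2048-dim value and a 2048-dim gate, and + finally projected back from 2048 to 512.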
+ """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + """ + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to( + dtype=gate.dtype + ) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + linear_cls = nn.Linear + + self.proj = linear_cls(dim_in, dim_out * 2) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states, scale: float = 1.0): + args = () + hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + r""" + The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2: + https://arxiv.org/abs/1606.08415. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. 
+ """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) diff --git a/tsr/models/transformer/transformer_1d.py b/tsr/models/transformer/transformer_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..7546232c5126a6622f1fd701ba36f9d4b53b9178 --- /dev/null +++ b/tsr/models/transformer/transformer_1d.py @@ -0,0 +1,216 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ...utils import BaseModule +from .basic_transformer_block import BasicTransformerBlock + + +class Transformer1D(BaseModule): + """ + A 1D Transformer model for sequence data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @dataclass + class Config(BaseModule.Config): + num_attention_heads: int = 16 + attention_head_dim: int = 88 + in_channels: Optional[int] = None + out_channels: Optional[int] = None + num_layers: int = 1 + dropout: float = 0.0 + norm_num_groups: int = 32 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_type: str = "layer_norm" + norm_elementwise_affine: bool = True + gradient_checkpointing: bool = False + + cfg: Config + + def configure(self) -> None: + self.num_attention_heads = self.cfg.num_attention_heads + self.attention_head_dim = self.cfg.attention_head_dim + inner_dim = self.num_attention_heads * self.attention_head_dim + + linear_cls = nn.Linear + + # 2. Define input layers + self.in_channels = self.cfg.in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=self.cfg.norm_num_groups, + num_channels=self.cfg.in_channels, + eps=1e-6, + affine=True, + ) + self.proj_in = linear_cls(self.cfg.in_channels, inner_dim) + + # 3. 
Define transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + self.num_attention_heads, + self.attention_head_dim, + dropout=self.cfg.dropout, + cross_attention_dim=self.cfg.cross_attention_dim, + activation_fn=self.cfg.activation_fn, + attention_bias=self.cfg.attention_bias, + only_cross_attention=self.cfg.only_cross_attention, + double_self_attention=self.cfg.double_self_attention, + upcast_attention=self.cfg.upcast_attention, + norm_type=self.cfg.norm_type, + norm_elementwise_affine=self.cfg.norm_elementwise_affine, + ) + for _ in range(self.cfg.num_layers) + ] + ) + + # 3. Define output layers + self.out_channels = ( + self.cfg.in_channels + if self.cfg.out_channels is None + else self.cfg.out_channels + ) + + # the residual connection in forward() requires the output width to match in_channels + self.proj_out = linear_cls(inner_dim, self.cfg.in_channels) + + self.gradient_checkpointing = self.cfg.gradient_checkpointing + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ): + """ + The [`Transformer1D`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, sequence len)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + attention_mask (`torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` applied to self-attention. If `1` the mask + is kept, otherwise if `0` it is discarded. The mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask (`torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + + Returns: + `torch.Tensor`: The output hidden states, of the same shape as the input `hidden_states`. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. 
if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(hidden_states.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + batch, _, seq_len = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 1).reshape( + batch, seq_len, inner_dim + ) + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + block, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + use_reentrant=False, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + + # 3. 
Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, seq_len, inner_dim) + .permute(0, 2, 1) + .contiguous() + ) + + output = hidden_states + residual + + return output diff --git a/tsr/system.py b/tsr/system.py new file mode 100644 index 0000000000000000000000000000000000000000..de5ae6ef75afd3bbbe0bdfacc7258a8c51409cc5 --- /dev/null +++ b/tsr/system.py @@ -0,0 +1,203 @@ +import math +import os +from dataclasses import dataclass, field +from typing import List, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +import trimesh +from einops import rearrange +from huggingface_hub import hf_hub_download +from omegaconf import OmegaConf +from PIL import Image + +from .models.isosurface import MarchingCubeHelper +from .utils import ( + BaseModule, + ImagePreprocessor, + find_class, + get_spherical_cameras, + scale_tensor, +) + + +class TSR(BaseModule): + @dataclass + class Config(BaseModule.Config): + cond_image_size: int + + image_tokenizer_cls: str + image_tokenizer: dict + + tokenizer_cls: str + tokenizer: dict + + backbone_cls: str + backbone: dict + + post_processor_cls: str + post_processor: dict + + decoder_cls: str + decoder: dict + + renderer_cls: str + renderer: dict + + cfg: Config + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str, token=None + ): + if os.path.isdir(pretrained_model_name_or_path): + config_path = os.path.join(pretrained_model_name_or_path, config_name) + weight_path = os.path.join(pretrained_model_name_or_path, weight_name) + else: + config_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, filename=config_name, token=token + ) + weight_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, filename=weight_name, token=token + ) + + cfg = OmegaConf.load(config_path) + OmegaConf.resolve(cfg) + model = cls(cfg) + ckpt = torch.load(weight_path, map_location="cpu") + model.load_state_dict(ckpt) + return model + + def configure(self): + self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)( + self.cfg.image_tokenizer + ) + self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer) + self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone) + self.post_processor = find_class(self.cfg.post_processor_cls)( + self.cfg.post_processor + ) + self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder) + self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer) + self.image_processor = ImagePreprocessor() + self.isosurface_helper = None + + def forward( + self, + image: Union[ + PIL.Image.Image, + np.ndarray, + torch.FloatTensor, + List[PIL.Image.Image], + List[np.ndarray], + List[torch.FloatTensor], + ], + device: str, + ) -> torch.FloatTensor: + rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to( + device + ) + batch_size = rgb_cond.shape[0] + + input_image_tokens: torch.Tensor = self.image_tokenizer( + rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1), + ) + + input_image_tokens = rearrange( + input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1 + ) + + tokens: torch.Tensor = self.tokenizer(batch_size) + + tokens = self.backbone( + tokens, + encoder_hidden_states=input_image_tokens, + ) + + scene_codes = self.post_processor(self.tokenizer.detokenize(tokens)) + return scene_codes + + def render( + self, + scene_codes, + n_views: int, + elevation_deg: float = 0.0, + camera_distance: float = 1.9, + fovy_deg: float = 
40.0, + height: int = 256, + width: int = 256, + return_type: str = "pil", + ): + rays_o, rays_d = get_spherical_cameras( + n_views, elevation_deg, camera_distance, fovy_deg, height, width + ) + rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device) + + def process_output(image: torch.FloatTensor): + if return_type == "pt": + return image + elif return_type == "np": + return image.detach().cpu().numpy() + elif return_type == "pil": + return Image.fromarray( + (image.detach().cpu().numpy() * 255.0).astype(np.uint8) + ) + else: + raise NotImplementedError + + images = [] + for scene_code in scene_codes: + images_ = [] + for i in range(n_views): + with torch.no_grad(): + image = self.renderer( + self.decoder, scene_code, rays_o[i], rays_d[i] + ) + images_.append(process_output(image)) + images.append(images_) + + return images + + def set_marching_cubes_resolution(self, resolution: int): + if ( + self.isosurface_helper is not None + and self.isosurface_helper.resolution == resolution + ): + return + self.isosurface_helper = MarchingCubeHelper(resolution) + + def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0): + self.set_marching_cubes_resolution(resolution) + meshes = [] + for scene_code in scene_codes: + with torch.no_grad(): + density = self.renderer.query_triplane( + self.decoder, + scale_tensor( + self.isosurface_helper.grid_vertices.to(scene_codes.device), + self.isosurface_helper.points_range, + (-self.renderer.cfg.radius, self.renderer.cfg.radius), + ), + scene_code, + )["density_act"] + v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold)) + v_pos = scale_tensor( + v_pos, + self.isosurface_helper.points_range, + (-self.renderer.cfg.radius, self.renderer.cfg.radius), + ) + with torch.no_grad(): + color = self.renderer.query_triplane( + self.decoder, + v_pos, + scene_code, + )["color"] + mesh = trimesh.Trimesh( + vertices=v_pos.cpu().numpy(), + faces=t_pos_idx.cpu().numpy(), + vertex_colors=color.cpu().numpy(), + ) + meshes.append(mesh) + return meshes diff --git a/tsr/utils.py b/tsr/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6a1b59aef75d02d39b29222d300fe9241bb11444 --- /dev/null +++ b/tsr/utils.py @@ -0,0 +1,482 @@ +import importlib +import math +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import imageio +import numpy as np +import PIL.Image +import rembg +import torch +import torch.nn as nn +import torch.nn.functional as F +import trimesh +from omegaconf import DictConfig, OmegaConf +from PIL import Image + + +def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: + scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg) + return scfg + + +def find_class(cls_string): + module_string = ".".join(cls_string.split(".")[:-1]) + cls_name = cls_string.split(".")[-1] + module = importlib.import_module(module_string, package=None) + cls = getattr(module, cls_name) + return cls + + +def get_intrinsic_from_fov(fov, H, W, bs=-1): + focal_length = 0.5 * H / np.tan(0.5 * fov) + intrinsic = np.identity(3, dtype=np.float32) + intrinsic[0, 0] = focal_length + intrinsic[1, 1] = focal_length + intrinsic[0, 2] = W / 2.0 + intrinsic[1, 2] = H / 2.0 + + if bs > 0: + intrinsic = intrinsic[None].repeat(bs, axis=0) + + return torch.from_numpy(intrinsic) + + +class BaseModule(nn.Module): + @dataclass + class Config: + pass + + cfg: Config # add this to every subclass of 
BaseModule to enable static type checking + + def __init__( + self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs + ) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + self.configure(*args, **kwargs) + + def configure(self, *args, **kwargs) -> None: + raise NotImplementedError + + +class ImagePreprocessor: + def convert_and_resize( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], + size: int, + ): + if isinstance(image, PIL.Image.Image): + image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0) + elif isinstance(image, np.ndarray): + if image.dtype == np.uint8: + image = torch.from_numpy(image.astype(np.float32) / 255.0) + else: + image = torch.from_numpy(image) + elif isinstance(image, torch.Tensor): + pass + + batched = image.ndim == 4 + + if not batched: + image = image[None, ...] + image = F.interpolate( + image.permute(0, 3, 1, 2), + (size, size), + mode="bilinear", + align_corners=False, + antialias=True, + ).permute(0, 2, 3, 1) + if not batched: + image = image[0] + return image + + def __call__( + self, + image: Union[ + PIL.Image.Image, + np.ndarray, + torch.FloatTensor, + List[PIL.Image.Image], + List[np.ndarray], + List[torch.FloatTensor], + ], + size: int, + ) -> Any: + if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4: + image = self.convert_and_resize(image, size) + else: + if not isinstance(image, list): + image = [image] + image = [self.convert_and_resize(im, size) for im in image] + image = torch.stack(image, dim=0) + return image + + +def rays_intersect_bbox( + rays_o: torch.Tensor, + rays_d: torch.Tensor, + radius: float, + near: float = 0.0, + valid_thresh: float = 0.01, +): + input_shape = rays_o.shape[:-1] + rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3) + rays_d_valid = torch.where( + rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d + ) + if isinstance(radius, (int, float)): + radius = torch.FloatTensor( + [[-radius, radius], [-radius, radius], [-radius, radius]] + ).to(rays_o.device) + radius = ( + 1.0 - 1.0e-3 + ) * radius # tighten the radius to make sure the intersection point lies in the bounding box + interx0 = (radius[..., 1] - rays_o) / rays_d_valid + interx1 = (radius[..., 0] - rays_o) / rays_d_valid + t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near) + t_far = torch.maximum(interx0, interx1).amin(dim=-1) + + # check whether a ray intersects the bbox or not + rays_valid = t_far - t_near > valid_thresh + + t_near[torch.where(~rays_valid)] = 0.0 + t_far[torch.where(~rays_valid)] = 0.0 + + t_near = t_near.view(*input_shape, 1) + t_far = t_far.view(*input_shape, 1) + rays_valid = rays_valid.view(*input_shape) + + return t_near, t_far, rays_valid + + +def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any: + if chunk_size <= 0: + return func(*args, **kwargs) + B = None + for arg in list(args) + list(kwargs.values()): + if isinstance(arg, torch.Tensor): + B = arg.shape[0] + break + assert ( + B is not None + ), "No tensor found in args or kwargs, cannot determine batch size." 
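+    # Editorial usage sketch (the names below are illustrative, not part of the original):
+    #     out = chunk_batch(decoder, 2**16, points)  # run `decoder` on 65536-row slices of `points`
+    # Tensor args/kwargs are sliced along dim 0; non-tensor arguments are passed through unchanged.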
+ out = defaultdict(list) + out_type = None + # max(1, B) to support B == 0 + for i in range(0, max(1, B), chunk_size): + out_chunk = func( + *[ + arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg + for arg in args + ], + **{ + k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg + for k, arg in kwargs.items() + }, + ) + if out_chunk is None: + continue + out_type = type(out_chunk) + if isinstance(out_chunk, torch.Tensor): + out_chunk = {0: out_chunk} + elif isinstance(out_chunk, (tuple, list)): + chunk_length = len(out_chunk) + out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)} + elif isinstance(out_chunk, dict): + pass + else: + raise TypeError( + f"Return value of func must be of type [torch.Tensor, list, tuple, dict], got {type(out_chunk)}." + ) + for k, v in out_chunk.items(): + v = v if torch.is_grad_enabled() else v.detach() + out[k].append(v) + + if out_type is None: + return None + + out_merged: Dict[Any, Optional[torch.Tensor]] = {} + for k, v in out.items(): + if all([vv is None for vv in v]): + # allow None in return value + out_merged[k] = None + elif all([isinstance(vv, torch.Tensor) for vv in v]): + out_merged[k] = torch.cat(v, dim=0) + else: + raise TypeError( + f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}" + ) + + if out_type is torch.Tensor: + return out_merged[0] + elif out_type in [tuple, list]: + return out_type([out_merged[i] for i in range(chunk_length)]) + elif out_type is dict: + return out_merged + + +ValidScale = Union[Tuple[float, float], torch.FloatTensor] + + +def scale_tensor(dat: torch.FloatTensor, inp_scale: ValidScale, tgt_scale: ValidScale): + if inp_scale is None: + inp_scale = (0, 1) + if tgt_scale is None: + tgt_scale = (0, 1) + if isinstance(tgt_scale, torch.FloatTensor): + assert dat.shape[-1] == tgt_scale.shape[-1] + dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) + dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] + return dat + + +def get_activation(name) -> Callable: + if name is None: + return lambda x: x + name = name.lower() + if name == "none": + return lambda x: x + elif name == "exp": + return lambda x: torch.exp(x) + elif name == "sigmoid": + return lambda x: torch.sigmoid(x) + elif name == "tanh": + return lambda x: torch.tanh(x) + elif name == "softplus": + return lambda x: F.softplus(x) + else: + try: + return getattr(F, name) + except AttributeError: + raise ValueError(f"Unknown activation function: {name}") + + +def get_ray_directions( + H: int, + W: int, + focal: Union[float, Tuple[float, float]], + principal: Optional[Tuple[float, float]] = None, + use_pixel_centers: bool = True, + normalize: bool = True, +) -> torch.FloatTensor: + """ + Get ray directions for all pixels in camera coordinate. 
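+ + Editorial clarification: the convention below is OpenGL-style, i.e. x right, + y up, camera looking along -z, which is why the y term is negated and the z + component is -1.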
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ray-tracing-generating-camera-rays/standard-coordinate-systems + + Inputs: + H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether to use pixel centers + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + pixel_center = 0.5 if use_pixel_centers else 0 + + if isinstance(focal, float): + fx, fy = focal, focal + cx, cy = W / 2, H / 2 + else: + fx, fy = focal + assert principal is not None + cx, cy = principal + + i, j = torch.meshgrid( + torch.arange(W, dtype=torch.float32) + pixel_center, + torch.arange(H, dtype=torch.float32) + pixel_center, + indexing="xy", + ) + + directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1) + + if normalize: + directions = F.normalize(directions, dim=-1) + + return directions + + +def get_rays( + directions, + c2w, + keepdim=False, + noise_scale=0.0, + normalize=False, +) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + # Rotate ray directions from camera coordinate to the world coordinate + assert directions.shape[-1] == 3 + + if directions.ndim == 2: # (N_rays, 3) + if c2w.ndim == 2: # (4, 4) + c2w = c2w[None, :, :] + assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4) + rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3) + rays_o = c2w[:, :3, 3].expand(rays_d.shape) + elif directions.ndim == 3: # (H, W, 3) + assert c2w.ndim in [2, 3] + if c2w.ndim == 2: # (4, 4) + rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum( + -1 + ) # (H, W, 3) + rays_o = c2w[None, None, :3, 3].expand(rays_d.shape) + elif c2w.ndim == 3: # (B, 4, 4) + rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum( + -1 + ) # (B, H, W, 3) + rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape) + elif directions.ndim == 4: # (B, H, W, 3) + assert c2w.ndim == 3 # (B, 4, 4) + rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum( + -1 + ) # (B, H, W, 3) + rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape) + + # add camera noise to avoid grid-like artifact + # https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373 + if noise_scale > 0: + rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale + rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale + + if normalize: + rays_d = F.normalize(rays_d, dim=-1) + if not keepdim: + rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3) + + return rays_o, rays_d + + +def get_spherical_cameras( + n_views: int, + elevation_deg: float, + camera_distance: float, + fovy_deg: float, + height: int, + width: int, +): + azimuth_deg = torch.linspace(0, 360.0, n_views + 1)[:n_views] + elevation_deg = torch.full_like(azimuth_deg, elevation_deg) + camera_distances = torch.full_like(elevation_deg, camera_distance) + + elevation = elevation_deg * math.pi / 180 + azimuth = azimuth_deg * math.pi / 180 + + # convert spherical coordinates to cartesian coordinates + # right hand coordinate system, x back, y right, z up + # elevation in (-90, 90), azimuth from +x to +y in (-180, 180) + camera_positions = torch.stack( + [ + camera_distances * torch.cos(elevation) * torch.cos(azimuth), + camera_distances * torch.cos(elevation) * torch.sin(azimuth), + camera_distances * torch.sin(elevation), + ], + dim=-1, + ) + + # default scene center at origin + center = torch.zeros_like(camera_positions) + # default camera up direction 
as +z + up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1) + + fovy = torch.full_like(elevation_deg, fovy_deg) * math.pi / 180 + + lookat = F.normalize(center - camera_positions, dim=-1) + right = F.normalize(torch.cross(lookat, up), dim=-1) + up = F.normalize(torch.cross(right, lookat), dim=-1) + c2w3x4 = torch.cat( + [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]], + dim=-1, + ) + c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1) + c2w[:, 3, 3] = 1.0 + + # get directions by dividing directions_unit_focal by focal length + focal_length = 0.5 * height / torch.tan(0.5 * fovy) + directions_unit_focal = get_ray_directions( + H=height, + W=width, + focal=1.0, + ) + directions = directions_unit_focal[None, :, :, :].repeat(n_views, 1, 1, 1) + directions[:, :, :, :2] = ( + directions[:, :, :, :2] / focal_length[:, None, None, None] + ) + # must use normalize=True to normalize directions here + rays_o, rays_d = get_rays(directions, c2w, keepdim=True, normalize=True) + + return rays_o, rays_d + + +def remove_background( + image: PIL.Image.Image, + rembg_session: Any = None, + force: bool = False, + **rembg_kwargs, +) -> PIL.Image.Image: + do_remove = True + if image.mode == "RGBA" and image.getextrema()[3][0] < 255: + do_remove = False + do_remove = do_remove or force + if do_remove: + image = rembg.remove(image, session=rembg_session, **rembg_kwargs) + return image + + +def resize_foreground( + image: PIL.Image.Image, + ratio: float, +) -> PIL.Image.Image: + image = np.array(image) + assert image.shape[-1] == 4 + alpha = np.where(image[..., 3] > 0) + y1, y2, x1, x2 = ( + alpha[0].min(), + alpha[0].max(), + alpha[1].min(), + alpha[1].max(), + ) + # crop the foreground + fg = image[y1:y2, x1:x2] + # pad to square + size = max(fg.shape[0], fg.shape[1]) + ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2 + ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0 + new_image = np.pad( + fg, + ((ph0, ph1), (pw0, pw1), (0, 0)), + mode="constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + + # compute padding according to the ratio + new_size = int(new_image.shape[0] / ratio) + # pad to size, double side + ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2 + ph1, pw1 = new_size - size - ph0, new_size - size - pw0 + new_image = np.pad( + new_image, + ((ph0, ph1), (pw0, pw1), (0, 0)), + mode="constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + new_image = PIL.Image.fromarray(new_image) + return new_image + + +def save_video( + frames: List[PIL.Image.Image], + output_path: str, + fps: int = 30, +): + # use imageio to save video + frames = [np.array(frame) for frame in frames] + writer = imageio.get_writer(output_path, fps=fps) + for frame in frames: + writer.append_data(frame) + writer.close() + + +def to_gradio_3d_orientation(mesh): + mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi/2, [1, 0, 0])) + # mesh.apply_scale([1, 1, -1]) + mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi/2, [0, 1, 0])) + return mesh diff --git a/wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl b/wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl new file mode 100644 index 0000000000000000000000000000000000000000..88fdefc57ff16c5fd6354342b8801509d33c529e --- /dev/null +++ b/wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4af160ba1274e2205d3529a7b82efdb6946c2158a78e19631ed840301055b8d6 +size 5824388