Copied from EthanZyh/DiffusionText2WorldGeneration.
This view is limited to 50 files because the diff contains too many changes; see the raw diff for the complete change set.
- .gitignore +243 -0
- LICENSE +201 -0
- README.md +97 -0
- RELEASE.md +7 -0
- aegis.py +131 -0
- ar_config_tokenizer.py +137 -0
- ar_configs_base_model.py +118 -0
- ar_model.py +596 -0
- ar_modules_attention.py +262 -0
- ar_modules_embedding.py +491 -0
- ar_modules_mlp.py +50 -0
- ar_modules_normalization.py +88 -0
- ar_networks.py +63 -0
- ar_tokenizer.py +322 -0
- ar_tokenizer_image_text_tokenizer.py +318 -0
- ar_tokenizer_modules.py +560 -0
- ar_tokenizer_patching.py +279 -0
- ar_tokenizer_quantizers.py +165 -0
- ar_tokenizer_text_tokenizer.py +317 -0
- ar_tokenizer_utils.py +101 -0
- ar_transformer.py +461 -0
- ar_utils_misc.py +52 -0
- attention.py +305 -0
- base_world_generation_pipeline.py +362 -0
- batch_ops.py +46 -0
- blocklist.py +219 -0
- blocks.py +545 -0
- blur_utils.py +35 -0
- categories.py +192 -0
- checkpoint.py +76 -0
- conditioner.py +323 -0
- config.json +10 -0
- config.py +166 -0
- config_base_conditioner.py +169 -0
- config_helper.py +198 -0
- convert_pixtral_ckpt.py +209 -0
- cosmos1/models/POST_TRAINING.md +23 -0
- cosmos1/models/autoregressive/README.md +427 -0
- cosmos1/models/autoregressive/__init__.py +14 -0
- cosmos1/models/autoregressive/assets/nemo/finetuned_result.mp4 +0 -0
- cosmos1/models/autoregressive/assets/v1p0/batch_inputs/0.mp4 +0 -0
- cosmos1/models/autoregressive/assets/v1p0/batch_inputs/1.mp4 +0 -0
- cosmos1/models/autoregressive/assets/v1p0/batch_inputs/2.mp4 +0 -0
- cosmos1/models/autoregressive/assets/v1p0/batch_inputs/3.mp4 +0 -0
- cosmos1/models/autoregressive/assets/v1p0/batch_inputs/4.mp4 +0 -0
- cosmos1/models/autoregressive/assets/v1p0/batch_inputs/5.mp4 +0 -0
- cosmos1/models/autoregressive/assets/v1p0/batch_inputs/6.mp4 +0 -0
- cosmos1/models/autoregressive/assets/v1p0/batch_inputs/7.mp4 +0 -0
- cosmos1/models/autoregressive/assets/v1p0/batch_inputs/8.mp4 +0 -0
- cosmos1/models/autoregressive/assets/v1p0/batch_inputs/9.mp4 +0 -0
.gitignore
ADDED
@@ -0,0 +1,243 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Misc
outputs/
checkpoints/*
!checkpoints/README.md

# Data types
*.jit
*.pt
*.hdr
*.webp
*.pgm
*.tiff
*.tif
*.tar
*.tar.gz
*.gz
*.pkl
*.pt
*.bin

# Other uncheckable file types
*.zip
*.exe
*.dll
*.swp
*.vscode
*.ipynb
*.DS_Store
*.pyc
*Thumbs.db
*.patch

# Credential information that should never be checked in
credentials
*.secret

# ------------------------ BELOW IS AUTO-GENERATED FOR PYTHON REPOS ------------------------

# Byte-compiled / optimized / DLL files
**/__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
results/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.config
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Third party
# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# ruff
.ruff_cache

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
CLIP
.devcontainer/devcontainer.json

# Coverage
.coverage
coverage.xml

# JUnit Reports
report.xml

# CI-CD
temp/
envs.txt
manifest.json

# locks and t5 temp files
*.locks*
*.no_exist*
*models--t5*

# OneLogger
wandb/
onelogger.err
onelogger.log
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md
ADDED
@@ -0,0 +1,97 @@
## How to Use

```python
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "EthanZyh/DiffusionText2WorldGeneration",
    cache_dir="./cache",
    trust_remote_code=True,
    # turn on offloading on a low GPU memory machine:
    # offload_network=True,
    # offload_tokenizer=True,
    # offload_text_encoder_model=True,
    # offload_prompt_upsampler=True,
    # offload_guardrail_models=True,
)
prompt = "Some text prompt to generate a video"
model(prompt)
```



--------------------------------------------------------------------------------
### [Website](https://www.nvidia.com/en-us/ai/cosmos/) | [HuggingFace](https://huggingface.co/collections/nvidia/cosmos-6751e884dc10e013a0a0d8e6) | [GPU-free Preview](https://build.nvidia.com/explore/discover) | [Paper](https://arxiv.org/abs/2501.03575) | [Paper Website](https://research.nvidia.com/labs/dir/cosmos1/)

[NVIDIA Cosmos](https://www.nvidia.com/cosmos/) is a developer-first world foundation model platform designed to help Physical AI developers build their Physical AI systems better and faster. Cosmos contains

1. pre-trained models, available via [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-6751e884dc10e013a0a0d8e6) under the [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/) that allows commercial use of the models for free
2. training scripts under the [Apache 2 License](https://www.apache.org/licenses/LICENSE-2.0), offered through [NVIDIA Nemo Framework](https://github.com/NVIDIA/NeMo) for post-training the models for various downstream Physical AI applications

Details of the platform are described in the [Cosmos paper](https://research.nvidia.com/publication/2025-01_cosmos-world-foundation-model-platform-physical-ai). Preview access is available at [build.nvidia.com](https://build.nvidia.com).

## Key Features

- [Pre-trained Diffusion-based world foundation models](cosmos1/models/diffusion/README.md) for Text2World and Video2World generation, where a user can generate visual simulations based on text prompts and video prompts.
- [Pre-trained Autoregressive-based world foundation models](cosmos1/models/autoregressive/README.md) for Video2World generation, where a user can generate visual simulations based on video prompts and optional text prompts.
- [Video tokenizers](https://github.com/NVIDIA/Cosmos-Tokenizer) for tokenizing videos into continuous tokens (latent vectors) and discrete tokens (integers) efficiently and effectively.
- Video curation pipeline for building your own video dataset. [Coming soon]
- [Post-training scripts](cosmos1/models/POST_TRAINING.md) via NeMo Framework to post-train the pre-trained world foundation models for various Physical AI setups.
- Pre-training scripts via NeMo Framework for building your own world foundation model. [[Diffusion](https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/diffusion)] [[Autoregressive](https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/multimodal_autoregressive)] [[Tokenizer](https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/diffusion/vae)].

## Model Family

| Model name | Description | Try it out |
|------------|----------|----------|
| [Cosmos-1.0-Diffusion-7B-Text2World](https://huggingface.co/nvidia/Cosmos-1.0-Diffusion-7B-Text2World) | Text to visual world generation | [Inference](cosmos1/models/diffusion/README.md) |
| [Cosmos-1.0-Diffusion-14B-Text2World](https://huggingface.co/nvidia/Cosmos-1.0-Diffusion-14B-Text2World) | Text to visual world generation | [Inference](cosmos1/models/diffusion/README.md) |
| [Cosmos-1.0-Diffusion-7B-Video2World](https://huggingface.co/nvidia/Cosmos-1.0-Diffusion-7B-Video2World) | Video + Text based future visual world generation | [Inference](cosmos1/models/diffusion/README.md) |
| [Cosmos-1.0-Diffusion-14B-Video2World](https://huggingface.co/nvidia/Cosmos-1.0-Diffusion-14B-Video2World) | Video + Text based future visual world generation | [Inference](cosmos1/models/diffusion/README.md) |
| [Cosmos-1.0-Autoregressive-4B](https://huggingface.co/nvidia/Cosmos-1.0-Autoregressive-4B) | Future visual world generation | [Inference](cosmos1/models/autoregressive/README.md) |
| [Cosmos-1.0-Autoregressive-12B](https://huggingface.co/nvidia/Cosmos-1.0-Autoregressive-12B) | Future visual world generation | [Inference](cosmos1/models/autoregressive/README.md) |
| [Cosmos-1.0-Autoregressive-5B-Video2World](https://huggingface.co/nvidia/Cosmos-1.0-Autoregressive-5B-Video2World) | Video + Text based future visual world generation | [Inference](cosmos1/models/autoregressive/README.md) |
| [Cosmos-1.0-Autoregressive-13B-Video2World](https://huggingface.co/nvidia/Cosmos-1.0-Autoregressive-13B-Video2World) | Video + Text based future visual world generation | [Inference](cosmos1/models/autoregressive/README.md) |
| [Cosmos-1.0-Guardrail](https://huggingface.co/nvidia/Cosmos-1.0-Guardrail) | Guardrail contains pre-Guard and post-Guard for safe use | Embedded in model inference scripts |

## Example Usage

### Inference

Follow the [Cosmos Installation Guide](INSTALL.md) to set up the Docker environment. For inference with the pretrained models, please refer to [Cosmos Diffusion Inference](cosmos1/models/diffusion/README.md) and [Cosmos Autoregressive Inference](cosmos1/models/autoregressive/README.md).

The code snippet below provides a gist of the inference usage.

```bash
PROMPT="A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. \
The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. \
A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, \
suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. \
The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of \
field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."

# Example using 7B model
PYTHONPATH=$(pwd) python cosmos1/models/diffusion/inference/text2world.py \
    --checkpoint_dir checkpoints \
    --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Text2World \
    --prompt "$PROMPT" \
    --offload_prompt_upsampler \
    --video_save_name Cosmos-1.0-Diffusion-7B-Text2World
```

<video src="https://github.com/user-attachments/assets/db7bebfe-5314-40a6-b045-4f6ce0a87f2a">
Your browser does not support the video tag.
</video>

We also offer [multi-GPU inference](cosmos1/models/diffusion/nemo/inference/README.md) support for Diffusion Text2World WFM models through NeMo Framework.

### Post-training

NeMo Framework provides GPU-accelerated general post-training for both [diffusion](cosmos1/models/diffusion/nemo/post_training/README.md) and [autoregressive](cosmos1/models/autoregressive/nemo/post_training/README.md) models, with other types of post-training coming soon.

## License and Contact

This project will download and install additional third-party open source software projects. Review the license terms of these open source projects before use.

NVIDIA Cosmos source code is released under the [Apache 2 License](https://www.apache.org/licenses/LICENSE-2.0).

NVIDIA Cosmos models are released under the [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). For a custom license, please contact [cosmos-license@nvidia.com](mailto:cosmos-license@nvidia.com).
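Supplementing the "How to Use" snippet at the top of this README, the sketch below simply enables the offloading flags that the snippet lists as comments. It is a minimal illustration, assuming the remote pipeline accepts those keyword arguments exactly as shown there; how the generated video is returned or saved is handled by the pipeline itself and is not shown here.

```python
from transformers import AutoModel

# Same call as in the README snippet, with the offloading flags turned on
# for a low-GPU-memory machine (flags taken from the commented example above).
model = AutoModel.from_pretrained(
    "EthanZyh/DiffusionText2WorldGeneration",
    cache_dir="./cache",
    trust_remote_code=True,
    offload_network=True,
    offload_tokenizer=True,
    offload_text_encoder_model=True,
    offload_prompt_upsampler=True,
    offload_guardrail_models=True,
)
model("A robot arm stacks boxes in a brightly lit warehouse.")
```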
RELEASE.md
ADDED
@@ -0,0 +1,7 @@
# Release Cadence

| Version | Description | Date |
|------------|----------|----------|
| [v1.0](release_notes/v0p1.md) | Initial diffusion and autoregressive WFMs release | 2025-01-06 |
| [v0.1](release_notes/v0p1.md) | Initial tokenizer release | 2024-11-06 |
aegis.py
ADDED
@@ -0,0 +1,131 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

from .log import log
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from .categories import UNSAFE_CATEGORIES
from .guardrail_core import ContentSafetyGuardrail, GuardrailRunner
from .misc import misc, Color, timer

SAFE = Color.green("SAFE")
UNSAFE = Color.red("UNSAFE")

DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/aegis"


class Aegis(ContentSafetyGuardrail):
    def __init__(
        self,
        checkpoint_dir: str = DEFAULT_CHECKPOINT_DIR,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ) -> None:
        self.checkpoint_dir = checkpoint_dir
        self.device = device
        self.dtype = torch.bfloat16
        base_model_id = "meta-llama/LlamaGuard-7b"
        aegis_adapter = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0"
        base_model = AutoModelForCausalLM.from_pretrained(base_model_id, cache_dir=self.checkpoint_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id, cache_dir=self.checkpoint_dir)
        self.model = PeftModel.from_pretrained(base_model, aegis_adapter, cache_dir=self.checkpoint_dir)
        self.model.to(self.device, dtype=self.dtype).eval()

    def get_moderation_prompt(self, user_prompt: str) -> str:
        """Create the moderation prompt for the Aegis model."""
        full_prompt = f"""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
{"".join([f"{category['instruction']}" for category in UNSAFE_CATEGORIES])}
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

"""
        full_prompt += f"User: {user_prompt}\n\n"
        full_prompt += """<END CONVERSATION>

Provide your safety assessment for User in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""

        return full_prompt

    def get_aegis_block_message(self, moderation_output: str) -> str:
        """Extract the blocked category and reason from the Aegis model output."""
        block_msg = "Prompt blocked by Aegis."
        try:
            # Extract blocked category and reason
            if len(moderation_output.split("\n")) >= 2:
                moderation_output_list = moderation_output.split("\n")
                blocked_category = int(moderation_output_list[1].lstrip("O")) - 1
                if 0 <= blocked_category < len(UNSAFE_CATEGORIES):
                    blocked_reason = UNSAFE_CATEGORIES[blocked_category]["blocked_reason"]
                    blocked_category_name = UNSAFE_CATEGORIES[blocked_category]["category"]
                    block_msg = f"{blocked_category_name}: {blocked_reason}"
        except Exception as e:
            log.warning(f"Unable to extract blocked category and reason from Aegis output: {e}")
        return block_msg

    def filter_aegis_output(self, prompt: str) -> tuple[bool, str]:
        """Filter the Aegis model output and return the safety status and message."""
        full_prompt = self.get_moderation_prompt(prompt)
        inputs = self.tokenizer([full_prompt], add_special_tokens=False, return_tensors="pt").to(self.device)
        output = self.model.generate(**inputs, max_new_tokens=100, pad_token_id=self.tokenizer.eos_token_id)
        prompt_len = inputs["input_ids"].shape[-1]
        moderation_output = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

        if "unsafe" in moderation_output.lower():
            block_msg = self.get_aegis_block_message(moderation_output)
            return False, block_msg
        else:
            return True, ""

    def is_safe(self, prompt: str) -> tuple[bool, str]:
        """Check if the input prompt is safe according to the Aegis model."""
        try:
            return self.filter_aegis_output(prompt)
        except Exception as e:
            log.error(f"Unexpected error occurred when running Aegis guardrail: {e}")
            return True, "Unexpected error occurred when running Aegis guardrail."


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        help="Path to the Aegis checkpoint folder",
        default=DEFAULT_CHECKPOINT_DIR,
    )
    return parser.parse_args()


def main(args):
    aegis = Aegis(checkpoint_dir=args.checkpoint_dir)
    runner = GuardrailRunner(safety_models=[aegis])
    with timer("aegis safety check"):
        safety, message = runner.run_safety_check(args.prompt)
    log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}")
    log.info(f"Message: {message}") if not safety else None


if __name__ == "__main__":
    args = parse_args()
    main(args)
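aegis.py is written to run both as a CLI (via parse_args/main above) and programmatically. Below is a minimal sketch of the programmatic path, mirroring main(). It assumes the surrounding package is importable so that the relative imports (.guardrail_core, .categories, .log, .misc) resolve; the top-level module names used in the import lines are illustrative stand-ins for however the package exposes them.

```python
# Minimal usage sketch (module names are illustrative; the real file uses
# package-relative imports such as `from .guardrail_core import GuardrailRunner`).
from aegis import Aegis, DEFAULT_CHECKPOINT_DIR
from guardrail_core import GuardrailRunner

guard = Aegis(checkpoint_dir=DEFAULT_CHECKPOINT_DIR)   # loads LlamaGuard-7b plus the Aegis LoRA adapter
runner = GuardrailRunner(safety_models=[guard])        # same wiring as main() above

is_safe, message = runner.run_safety_check("A robot stacks boxes in a warehouse.")
print("SAFE" if is_safe else f"UNSAFE: {message}")
```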
ar_config_tokenizer.py
ADDED
@@ -0,0 +1,137 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import attrs

from .discrete_video import DiscreteVideoFSQStateDictTokenizer
from .ar_networks import CausalDiscreteVideoTokenizer
from .lazy_config_init import LazyCall as L
from .lazy_config_init import LazyDict


def create_discrete_video_fsq_tokenizer_state_dict_config(
    ckpt_path, pixel_chunk_duration=33, compression_ratio=[8, 16, 16]
) -> LazyDict:
    CausalDiscreteFactorizedVideoTokenizerConfig: LazyDict = L(CausalDiscreteVideoTokenizer)(
        # The new causal discrete tokenizer, that is at least 2x more efficient in memory and runtime.
        # - It relies on fully 3D discrete wavelet transform
        # - Uses a layer norm instead of a group norm
        # - Factorizes full convolutions into spatial and temporal convolutions
        # - Factorizes full attention into spatial and temporal attention
        # - Strictly causal, with flexible temporal length at inference.
        attn_resolutions=[32],
        channels=128,
        channels_mult=[2, 4, 4],
        dropout=0.0,
        in_channels=3,
        num_res_blocks=2,
        out_channels=3,
        resolution=1024,
        patch_size=4,
        patch_method="haar",
        z_channels=16,
        z_factor=1,
        num_groups=1,
        legacy_mode=False,
        spatial_compression=16,
        temporal_compression=8,
        embedding_dim=6,
        levels=[8, 8, 8, 5, 5, 5],
        name="CausalDiscreteFactorizedVideoTokenizer",
    )

    return L(DiscreteVideoFSQStateDictTokenizer)(
        enc_fp=ckpt_path.replace("ema.jit", "encoder.jit"),
        dec_fp=ckpt_path.replace("ema.jit", "decoder.jit"),
        tokenizer_module=CausalDiscreteFactorizedVideoTokenizerConfig,
        name="discrete_video_fsq",
        latent_ch=6,
        is_bf16=True,
        pixel_chunk_duration=pixel_chunk_duration,
        latent_chunk_duration=1 + (pixel_chunk_duration - 1) // compression_ratio[0],
        max_enc_batch_size=8,
        max_dec_batch_size=4,
        levels=[8, 8, 8, 5, 5, 5],
        compression_ratio=compression_ratio,
    )


@attrs.define(slots=False)
class TextTokenizerConfig:
    """
    Text tokenizer config

    Args:
        config: Config file to define the text tokenizer class.
        data_key (str): The input key from data_dict that will be passed to the text tokenizer.
        tokenize_here (bool): Whether to use the tokenizer to perform online tokenization.
        tokenizer_offset (int): Offset that is added to the tokens.
        vocab_size (int): Vocabulary size of the tokenizer.
    """

    config: LazyDict
    data_key: str = ""
    tokenize_here: bool = False
    tokenizer_offset: int = 0
    vocab_size: int = 0


@attrs.define(slots=False)
class VideoTokenizerConfig:
    """
    Video tokenizer config

    Args:
        config: Config file to define the video tokenizer class.
        data_key (str): The input key from data_dict that will be passed to the video tokenizer.
        tokenize_here (bool): Whether to use the tokenizer to perform online tokenization.
        tokenizer_offset (int): Offset that is added to the tokens. In case of joint text-video tokenizers, we
            add an offset to make sure that video tokens and text tokens don't overlap.
        vocab_size (int): Vocabulary size of the tokenizer.
        max_seq_len (int): Maximum token length for an input video.
    """

    config: LazyDict
    data_key: str = ""
    tokenize_here: bool = True
    tokenizer_offset: int = 0
    vocab_size: int = 0
    max_seq_len: int = -1


@attrs.define(slots=False)
class TokenizerConfig:
    """
    Joint tokenizer config

    Args:
        text_tokenizer (TextTokenizerConfig): Text tokenizer config file
        class_tokenizer (ClassTokenizerConfig): Class tokenizer config file
        video_tokenizer (VideoTokenizerConfig): Video tokenizer config file
        image_tokenizer (ImageTokenizerConfig): Image tokenizer config file
        seq_len (int): Final token sequence length
        training_type (str): Type of training we use. Supports ["text_only", "text_to_video", "class_to_image", "image_text_interleaved"]
        add_special_tokens (bool): Whether to add special tokens to the output tokens
        pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
    """

    text_tokenizer: Optional[TextTokenizerConfig] = None
    video_tokenizer: Optional[VideoTokenizerConfig] = None
    seq_len: int = 4096
    training_type: str = None
    add_special_tokens: bool = True
    pad_to_multiple_of: Optional[int] = 64
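To make the relationship between the factory function and these dataclasses concrete, here is a hedged sketch of assembling a video-only TokenizerConfig. The checkpoint path is a hypothetical placeholder, and the vocab_size is an illustrative value (the product of the FSQ levels 8*8*8*5*5*5 = 64000), not something read from this file.

```python
# Illustrative wiring of the configs defined above; assumes the package-relative
# imports resolve and the .jit tokenizer checkpoints exist at the given path.
video_tokenizer_cfg = VideoTokenizerConfig(
    config=create_discrete_video_fsq_tokenizer_state_dict_config(
        ckpt_path="checkpoints/video_tokenizer/ema.jit",  # hypothetical path
        pixel_chunk_duration=33,
        compression_ratio=[8, 16, 16],
    ),
    data_key="video",
    tokenizer_offset=0,
    vocab_size=64000,  # illustrative: product of the FSQ levels [8, 8, 8, 5, 5, 5]
    max_seq_len=-1,
)

tokenizer_config = TokenizerConfig(
    video_tokenizer=video_tokenizer_cfg,
    seq_len=4096,
    training_type="text_to_video",
    add_special_tokens=True,
)
```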
ar_configs_base_model.py
ADDED
@@ -0,0 +1,118 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import attrs

from .ar_config_tokenizer import TokenizerConfig


@attrs.define
class ModelConfig:
    """
    A class to hold model configuration arguments.

    Args:
        dim (int): The dimensionality of the input and output of each transformer block.
        n_layers (int): Number of layers in the transformer.
        n_heads (int): Number of attention heads.
        n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to
            `num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention.
        head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads.
        vocab_size (int): Vocabulary size.
        ffn_hidden_size (int): Hidden size for feedforward network.
        norm_eps (float): Epsilon value for normalization.
        rope_theta (float): Theta value for rotary positional embeddings.
        apply_abs_pos_emb (bool): Whether to apply absolute position embeddings.
        max_batch_size (int): Maximum batch size for inference.
        max_seq_len (int): Maximum sequence length for input text.
        fuse_qkv (bool): Whether to fuse QKV in attention. Defaults to True.
        causal_mask (bool): Whether to use causal mask. Defaults to True.
        norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm".
        precision (str): Data type for the model.
        use_qk_normalization (bool): Whether to enable QK normalization.
        ckpt_dir (str): Checkpoint directory.
        ckpt_path (str): Checkpoint path.
        apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension).
        yarn_scale (Optional[float]): Scale factor for YaRN.
        yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in Llama 3.1 RoPE scaling code)
        yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in Llama 3.1 RoPE scaling code)
        original_seq_len (Optional[int]): Original sequence length.
        vision_encoder (Optional[str]): Vision encoder name.
        mm_projector (Optional[str]): Multi-modal projector name.
        vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3;
            set it to a larger int if needed, e.g. 4 for 4-channel images whose last channel is an alpha channel.
        rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "3D".
        pytorch_rope_version (Optional[str]): Version of the PyTorch RoPE implementation. Choices: "v1", "v2".
        original_latent_shape (Optional[list]): Original shape of the latent tensor needed for rope extension.
        pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
        insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer.
        insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
        context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim.
        num_video_frames (Optional[int]): Number of video frames.
        video_height (Optional[int]): Raw video pixel height dimension.
        video_width (Optional[int]): Raw video pixel width dimension.
        video_latent_shape (Optional[list]): Video tokenizer output dimension, in (T,H,W).
    """

    dim: int = attrs.field(default=4096)
    n_layers: int = attrs.field(default=32)
    n_heads: int = attrs.field(default=32)
    n_kv_heads: Optional[int] = attrs.field(default=8)
    head_dim: Optional[int] = attrs.field(default=None)
    vocab_size: int = attrs.field(default=128256)
    ffn_hidden_size: int = attrs.field(default=14336)
    norm_eps: float = attrs.field(default=1e-5)
    rope_theta: float = attrs.field(default=500000)
    apply_abs_pos_emb: bool = attrs.field(default=False)
    max_batch_size: int = attrs.field(default=1)
    max_seq_len: int = attrs.field(default=8192)
    fuse_qkv: bool = attrs.field(default=False)
    causal_mask: bool = attrs.field(default=True)
    norm_type: str = attrs.field(default="rmsnorm")
    precision: str = attrs.field(default="bfloat16")
    use_qk_normalization: bool = False
    tokenizer: Optional[TokenizerConfig] = None
    ckpt_dir: Optional[str] = attrs.field(default=None)
    ckpt_path: Optional[str] = attrs.field(
        default=None
    )  # If not None, load the model from this path instead of ckpt_dir
    apply_yarn: Optional[bool] = attrs.field(default=False)
    yarn_scale: Optional[float] = attrs.field(default=None)
    yarn_beta_fast: Optional[int] = attrs.field(default=None)
    yarn_beta_slow: Optional[int] = attrs.field(default=None)
    original_seq_len: Optional[int] = attrs.field(default=None)
    vision_encoder: Optional[str] = attrs.field(default=None)
    vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
    mm_projector: Optional[str] = attrs.field(default=None)
    rope_dim: Optional[str] = attrs.field(default="1D")
    pytorch_rope_version: Optional[str] = attrs.field(default="v2")
    original_latent_shape: Optional[list] = None
    pad_to_multiple_of: Optional[int] = None
    insert_cross_attn: bool = False
    insert_cross_attn_every_k_layers: int = 1
    context_dim: Optional[int] = attrs.field(default=1024)
    # For video training
    num_video_frames: Optional[int] = None
    # Raw video pixel dimension
    video_height: Optional[int] = None
    video_width: Optional[int] = None
    # Video tokenizer output dimension (T, H, W): num_video_frames / temporal_compression_factor,
    # video_height / spatial_compression_factor, video_width / spatial_compression_factor
    video_latent_shape: Optional[list] = None

    def __getitem__(self, item):
        return getattr(self, item)
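A short sketch of how ModelConfig is typically instantiated and then read back, since __getitem__ allows dict-style access alongside normal attribute access. The field values below are illustrative only, not the shipped model presets; the video_latent_shape follows the (T, H, W) rule stated in the comment above for an assumed 8x16x16 compression.

```python
# Illustrative values only; real presets ship elsewhere in the package.
config = ModelConfig(
    dim=4096,
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,
    vocab_size=128256,
    max_seq_len=8192,
    precision="bfloat16",
    rope_dim="3D",                   # video models use 3D RoPE
    num_video_frames=33,
    video_height=640,
    video_width=1024,
    video_latent_shape=[5, 40, 64],  # (T, H, W) under assumed 8x16x16 compression
)

# __getitem__ lets the config be read like a dict as well as via attributes.
assert config["n_heads"] == config.n_heads == 32
```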
ar_model.py
ADDED
@@ -0,0 +1,596 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import json
|
17 |
+
import os
|
18 |
+
import time
|
19 |
+
from pathlib import Path
|
20 |
+
from typing import Any, Dict, List, Optional, Set
|
21 |
+
|
22 |
+
from .log import log
|
23 |
+
import torch
|
24 |
+
from safetensors.torch import load_file
|
25 |
+
from torch.nn.modules.module import _IncompatibleKeys
|
26 |
+
|
27 |
+
from .ar_configs_base_model import ModelConfig
|
28 |
+
from .ar_config_tokenizer import TokenizerConfig
|
29 |
+
from .mm_projector import MultimodalProjector
|
30 |
+
from .ar_transformer import Transformer
|
31 |
+
from .vit import VisionTransformer, get_vit_config
|
32 |
+
from .ar_tokenizer import DiscreteMultimodalTokenizer, update_vocab_size
|
33 |
+
from .checkpoint import (
|
34 |
+
get_partial_state_dict,
|
35 |
+
process_state_dict,
|
36 |
+
substrings_to_ignore,
|
37 |
+
)
|
38 |
+
from .sampling import decode_n_tokens, decode_one_token, prefill
|
39 |
+
from .misc import misc, Color, timer
|
40 |
+
|
41 |
+
|
42 |
+
class AutoRegressiveModel(torch.nn.Module):
|
43 |
+
"""
|
44 |
+
A class to build and use an AutoRegressiveModel for text generation.
|
45 |
+
|
46 |
+
Methods:
|
47 |
+
build: Build an AutoRegressiveModel instance by initializing and loading a model checkpoint.
|
48 |
+
generate: Generate text sequences based on provided prompts using the language generation model.
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
model: Transformer = None,
|
54 |
+
tokenizer: DiscreteMultimodalTokenizer = None,
|
55 |
+
config: ModelConfig = None,
|
56 |
+
vision_encoder: VisionTransformer = None,
|
57 |
+
mm_projector: MultimodalProjector = None,
|
58 |
+
):
|
59 |
+
"""
|
60 |
+
Initialize the AutoRegressiveModel instance with a model and tokenizer.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
model (Transformer): The Transformer model for text generation.
|
64 |
+
tokenizer (DiscreteMultimodalTokenizer): The tokenizer for encoding and decoding text.
|
65 |
+
config (ModelConfig): The configuration for the AutoRegressiveModel.
|
66 |
+
vision_encoder (VisionTransformer): The vision encoder for the AutoRegressiveModel.
|
67 |
+
mm_projector (MultimodalProjector): The multi-modal projector for the AutoRegressiveModel.
|
68 |
+
"""
|
69 |
+
super().__init__()
|
70 |
+
self.model = model
|
71 |
+
self.tokenizer = tokenizer
|
72 |
+
self.config = config
|
73 |
+
|
74 |
+
self.vision_encoder = vision_encoder
|
75 |
+
self.mm_projector = mm_projector
|
76 |
+
|
77 |
+
@property
|
78 |
+
def precision(self):
|
79 |
+
return self.model.precision
|
80 |
+
|
81 |
+
def get_num_params(
|
82 |
+
self,
|
83 |
+
) -> int:
|
84 |
+
"""
|
85 |
+
Return the number of parameters in the model.
|
86 |
+
"""
|
87 |
+
n_params = sum(p.numel() for p in self.parameters())
|
88 |
+
return n_params
|
89 |
+
|
90 |
+
def load_ar_model(
|
91 |
+
self,
|
92 |
+
tokenizer_config,
|
93 |
+
):
|
94 |
+
"""
|
95 |
+
Load the AR model.
|
96 |
+
"""
|
97 |
+
model_config = self.config
|
98 |
+
ckpt_path = model_config.ckpt_path
|
99 |
+
with timer(f"loading checkpoint from {ckpt_path}"):
|
100 |
+
if ckpt_path.endswith("safetensors"):
|
101 |
+
# Load with safetensors API
|
102 |
+
checkpoint = load_file(ckpt_path, device="cpu")
|
103 |
+
else:
|
104 |
+
# The pytorch version
|
105 |
+
checkpoint = torch.load(
|
106 |
+
ckpt_path,
|
107 |
+
map_location="cpu",
|
108 |
+
mmap=True, # load the checkpoint in memory-mapped mode
|
109 |
+
weights_only=True,
|
110 |
+
)
|
111 |
+
llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint
|
112 |
+
orig_precision = torch.get_default_dtype()
|
113 |
+
precision = getattr(torch, model_config.precision)
|
114 |
+
torch.set_default_dtype(precision)
|
115 |
+
log.debug(f"Setting torch default dtype to {precision}")
|
116 |
+
|
117 |
+
model = Transformer(
|
118 |
+
params=model_config,
|
119 |
+
tokenizer_config=tokenizer_config,
|
120 |
+
)
|
121 |
+
log.debug(
|
122 |
+
f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size}"
|
123 |
+
)
|
124 |
+
vocab_size = update_vocab_size(
|
125 |
+
existing_vocab_size=0,
|
126 |
+
to_be_added_vocab_size=tokenizer_config.video_tokenizer.vocab_size,
|
127 |
+
training_type=tokenizer_config.training_type,
|
128 |
+
add_special_tokens=False,
|
129 |
+
)
|
130 |
+
log.debug(
|
131 |
+
f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size} vocab_size {vocab_size}"
|
132 |
+
)
|
133 |
+
# Perform vocab expansion
|
134 |
+
if vocab_size > model.vocab_size:
|
135 |
+
log.debug(f"Expanding vocab size to {vocab_size}")
|
136 |
+
# For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer,
|
137 |
+
expand_output_layer = not (tokenizer_config.training_type == "text_to_video")
|
138 |
+
model.expand_vocab(
|
139 |
+
vocab_size,
|
140 |
+
init_method="gaussian",
|
141 |
+
expand_output_layer=expand_output_layer,
|
142 |
+
)
|
143 |
+
# Remove the "model." prefix in the state_dict
|
144 |
+
llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
|
145 |
+
with timer("loading state_dict into model"):
|
146 |
+
missing_keys, _ = model.load_state_dict(llm_checkpoint, strict=True)
|
147 |
+
# Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
|
148 |
+
missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
|
149 |
+
assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
|
150 |
+
|
151 |
+
self.model = model.to(precision).to("cuda")
|
152 |
+
torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value
|
153 |
+
|
154 |
+
def load_tokenizer(self, tokenizer_config):
|
155 |
+
"""
|
156 |
+
Load the tokenizer.
|
157 |
+
"""
|
158 |
+
self.tokenizer = DiscreteMultimodalTokenizer(tokenizer_config)
|
159 |
+
|
160 |
+
@staticmethod
|
161 |
+
def build(
|
162 |
+
model_config: ModelConfig = ModelConfig(),
|
163 |
+
tokenizer_config: TokenizerConfig = None,
|
164 |
+
) -> "AutoRegressiveModel":
|
165 |
+
"""
|
166 |
+
Build an AutoRegressiveModel instance by initializing and loading a model checkpoint.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
model_config (ModelConfig, optional): The model configuration for the AutoRegressiveModel instance. Defaults to ModelConfig().
|
170 |
+
tokenizer_config (TokenizerConfig, optional): The tokenizer configuration for the AutoRegressiveModel instance. Defaults to None.
|
172 |
+
Returns:
|
173 |
+
AutoRegressiveModel: An instance of the AutoRegressiveModel class with the loaded model and tokenizer.
|
174 |
+
|
175 |
+
Raises:
|
176 |
+
AssertionError: If there are no checkpoint files in the specified directory.
|
177 |
+
|
178 |
+
Note:
|
179 |
+
This method sets the device to CUDA and loads the pre-trained model and tokenizer.
|
180 |
+
"""
|
181 |
+
# Initialize model configuration parameters
|
182 |
+
config_params = {}
|
183 |
+
|
184 |
+
# Load checkpoint and model parameters
|
185 |
+
|
186 |
+
if model_config.ckpt_path is None:
|
187 |
+
# If ckpt_path is not provided, we assume the model checkpoint is saved in the ckpt_dir
|
188 |
+
ckpt_dir = model_config.ckpt_dir
|
189 |
+
|
190 |
+
# We prioritize safetensors version over the pytorch version, since the former is
|
191 |
+
# much faster for checkpoint loading.
|
192 |
+
checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors"))
|
193 |
+
if len(checkpoints) == 0:
|
194 |
+
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
195 |
+
|
196 |
+
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
197 |
+
assert (
|
198 |
+
len(checkpoints) == 1
|
199 |
+
), f"multiple checkpoint files found in {ckpt_dir} (currently only one is supported)"
|
200 |
+
ckpt_path = str(checkpoints[0]) # Assuming single checkpoint for non-parallel case
|
201 |
+
|
202 |
+
if os.path.exists(Path(ckpt_dir) / "config.json"):
|
203 |
+
with open(Path(ckpt_dir) / "config.json", "r") as f:
|
204 |
+
config_params = json.loads(f.read())
|
205 |
+
else:
|
206 |
+
log.info(
|
207 |
+
f"No config.json found in the checkpoint directory ({ckpt_dir}). Using the default model config."
|
208 |
+
)
|
209 |
+
|
210 |
+
else:
|
211 |
+
# If ckpt_path is provided, we load the model from the specified path,
|
212 |
+
# and use the default model configuration
|
213 |
+
ckpt_path = model_config.ckpt_path
|
214 |
+
|
215 |
+
for key, value in config_params.items():
|
216 |
+
if hasattr(model_config, key):
|
217 |
+
# Override the default model configuration with the parameters from the checkpoint
|
218 |
+
setattr(model_config, key, value)
|
219 |
+
|
220 |
+
with timer(f"loading checkpoint from {ckpt_path}"):
|
221 |
+
if ckpt_path.endswith("safetensors"):
|
222 |
+
# Load with safetensors API
|
223 |
+
checkpoint = load_file(ckpt_path, device="cpu")
|
224 |
+
else:
|
225 |
+
# The pytorch version
|
226 |
+
checkpoint = torch.load(
|
227 |
+
ckpt_path,
|
228 |
+
map_location="cpu",
|
229 |
+
mmap=True, # load the checkpoint in memory-mapped mode
|
230 |
+
weights_only=True,
|
231 |
+
)
|
232 |
+
llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint
|
233 |
+
|
234 |
+
if model_config.vision_encoder is not None:
|
235 |
+
# Take the LLM weights (starting with "model.") from the VLM checkpoint
|
236 |
+
llm_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="model.")
|
237 |
+
if model_config.vision_encoder is not None:
|
238 |
+
# For a vanilla VLM ckpt before fine-tuning, `checkpoint['model']` only contains LLM weights, while `checkpoint['vision_encoder']`
|
239 |
+
# and `checkpoint['mm_projector']` hold the vision-encoder and mm-projector weights, respectively.
|
240 |
+
# For fine-tuned VLM ckpt, `checkpoint['model']` contains all LLM, mm_projector and vision_encoder weights
|
241 |
+
if "vision_encoder" in checkpoint:
|
242 |
+
log.debug("Using pretrained vision_encoder")
|
243 |
+
vit_checkpoint = checkpoint["vision_encoder"]
|
244 |
+
else:
|
245 |
+
log.debug("Using fine-tuned vision_encoder")
|
246 |
+
vit_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="vision_encoder.")
|
247 |
+
vit_checkpoint = process_state_dict(vit_checkpoint, prefix_to_remove="vision_encoder.")
|
248 |
+
if "mm_projector" in checkpoint:
|
249 |
+
log.debug("Using pretrained mm_projector")
|
250 |
+
projector_checkpoint = checkpoint["mm_projector"]
|
251 |
+
else:
|
252 |
+
log.debug("Using fine-tuned mm_projector")
|
253 |
+
projector_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="mm_projector.")
|
254 |
+
projector_checkpoint = process_state_dict(projector_checkpoint, prefix_to_remove="mm_projector.")
|
255 |
+
assert (
|
256 |
+
len(vit_checkpoint) > 0 and len(projector_checkpoint) > 0
|
257 |
+
), "vit_checkpoint and projector_checkpoint cannot be empty. We do not support random initialization for vision_encoder and mm_projector."
|
258 |
+
|
259 |
+
tokenizer = DiscreteMultimodalTokenizer(tokenizer_config)
|
260 |
+
orig_precision = torch.get_default_dtype()
|
261 |
+
precision = getattr(torch, model_config.precision)
|
262 |
+
torch.set_default_dtype(precision)
|
263 |
+
log.debug(f"Setting torch default dtype to {precision}")
|
264 |
+
|
265 |
+
model = Transformer(
|
266 |
+
params=model_config,
|
267 |
+
tokenizer_config=tokenizer_config,
|
268 |
+
)
|
269 |
+
model_kwargs = {}
|
270 |
+
|
271 |
+
if model_config.vision_encoder is not None:
|
272 |
+
assert model_config.mm_projector is not None, "mm_projector must be provided if vision_encoder is provided."
|
273 |
+
vit_config = get_vit_config(model_config.vision_encoder)
|
274 |
+
vision_encoder = VisionTransformer.build(
|
275 |
+
vit_config,
|
276 |
+
)
|
277 |
+
|
278 |
+
mm_projector = MultimodalProjector(
|
279 |
+
mm_projector_type=model_config.mm_projector, in_dim=vit_config["dim"], out_dim=model_config["dim"]
|
280 |
+
)
|
281 |
+
model_kwargs.update({"vision_encoder": vision_encoder, "mm_projector": mm_projector})
|
282 |
+
|
283 |
+
# Perform vocab expansion
|
284 |
+
if tokenizer.vocab_size > model.vocab_size:
|
285 |
+
log.debug(f"Expanding vocab size to {tokenizer.vocab_size}")
|
286 |
+
# For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer,
|
287 |
+
expand_output_layer = not (tokenizer.training_type == "text_to_video")
|
288 |
+
model.expand_vocab(
|
289 |
+
tokenizer.vocab_size,
|
290 |
+
init_method="gaussian",
|
291 |
+
expand_output_layer=expand_output_layer,
|
292 |
+
)
|
293 |
+
|
294 |
+
# Remove the "model." prefix in the state_dict
|
295 |
+
llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
|
296 |
+
with timer("loading state_dict into model"):
|
297 |
+
missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True)
|
298 |
+
# Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
|
299 |
+
missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
|
300 |
+
assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
|
301 |
+
|
302 |
+
if model_config.vision_encoder is not None:
|
303 |
+
vision_encoder.load_state_dict(vit_checkpoint)
|
304 |
+
mm_projector.load_state_dict(projector_checkpoint)
|
305 |
+
if model_config.vision_encoder_in_channels != 3:
|
306 |
+
vision_encoder.expand_in_channels(model_config.vision_encoder_in_channels)
|
307 |
+
|
308 |
+
model = model.to(precision) # ensure model parameters are in the correct precision
|
309 |
+
log.debug(f"Model config: {model_config}")
|
310 |
+
|
311 |
+
model_class = AutoRegressiveModel
|
312 |
+
|
313 |
+
torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value
|
314 |
+
|
315 |
+
return model_class(model, tokenizer, model_config, **model_kwargs)
|
316 |
+
|
317 |
+
@torch.no_grad()
|
318 |
+
def generate(
|
319 |
+
self,
|
320 |
+
prompt_tokens: List[List[int]] | torch.Tensor,
|
321 |
+
max_gen_len: int,
|
322 |
+
temperature: float = 1.0,
|
323 |
+
top_k: Optional[int] = None,
|
324 |
+
top_p: Optional[float] = None,
|
325 |
+
num_gen_seq: int = 1,
|
326 |
+
logprobs: bool = False,
|
327 |
+
echo: bool = False,
|
328 |
+
seed: int = None,
|
329 |
+
context: Optional[torch.Tensor] = None,
|
330 |
+
context_mask: Optional[torch.Tensor] = None,
|
331 |
+
compile_sampling: bool = True,
|
332 |
+
compile_prefill: bool = False,
|
333 |
+
verbose: bool = True,
|
334 |
+
stop_tokens: Optional[Set[int]] = None,
|
335 |
+
images: Optional[torch.Tensor] = None,
|
336 |
+
):
|
337 |
+
"""
|
338 |
+
Autoregressive generation built upon the gpt-fast implementation (https://github.com/pytorch-labs/gpt-fast).
|
339 |
+
|
340 |
+
Args:
|
341 |
+
prompt_tokens (List[List[int]] | torch.Tensor): A single prompt of shape (1, seq_len).
|
342 |
+
max_gen_len (int): Maximum length of the generated text sequence.
|
343 |
+
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 1.0.
|
344 |
+
top_k (int, optional): Top-k value for top-k sampling. Defaults to None.
|
345 |
+
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to None.
|
346 |
+
num_gen_seq (int, optional): Number of outputs to generate given the same prompt. Defaults to 1. When temperature == 0, num_gen_seq must be 1 because the generation is deterministic.
|
347 |
+
echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
|
349 |
+
seed (int, optional): Random seed for reproducibility. Defaults to None.
|
350 |
+
compile_sampling (bool, optional): Flag indicating whether to compile the decoding function. Defaults to True.
|
351 |
+
compile_prefill (bool, optional): Flag indicating whether to compile the prefill function. Defaults to False.
|
352 |
+
verbose (bool, optional): Flag indicating whether to log timing information. Defaults to True.
|
353 |
+
"""
|
354 |
+
assert top_k is None or top_p is None, f"Only one of top_k ({top_k}) or top_p ({top_p}) should be specified."
|
355 |
+
if temperature == 0:
|
356 |
+
top_p, top_k = None, None
|
357 |
+
log.debug("Setting top_p and top_k to None because temperature is 0")
|
358 |
+
if top_p is not None:
|
359 |
+
log.debug(f"Using top-p sampling with p={top_p} and temperature={temperature}")
|
360 |
+
elif top_k is not None:
|
361 |
+
log.debug(f"Using top-k sampling with k={top_k} and temperature={temperature}")
|
362 |
+
else:
|
363 |
+
log.debug("Not applying top-k or top-p sampling. Will use top-k sampling with k=None")
|
364 |
+
|
365 |
+
orig_precision = torch.get_default_dtype()
|
366 |
+
torch.set_default_dtype(self.precision)
|
367 |
+
|
368 |
+
torch._inductor.config.coordinate_descent_tuning = True
|
369 |
+
torch._inductor.config.triton.unique_kernel_names = True
|
370 |
+
# Experimental features to reduce compilation times, will be on by default in future
|
371 |
+
torch._inductor.config.fx_graph_cache = True
|
372 |
+
|
373 |
+
if seed is not None:
|
374 |
+
misc.set_random_seed(seed)
|
375 |
+
|
376 |
+
assert not logprobs, "logprobs are not supported for fast_generate yet"
|
377 |
+
# Check whether the prefill and decode_one_token functions have been compiled yet. If not, compile them based on the flags
|
378 |
+
if compile_sampling and not getattr(self, "inference_decode_compiled", False):
|
379 |
+
self.decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
|
380 |
+
self.inference_decode_compiled = True
|
381 |
+
log.info("Compiled AR sampling function. Note: the first run will be slower due to compilation")
|
382 |
+
if compile_prefill and not getattr(self, "inference_prefill_compiled", False):
|
383 |
+
self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
|
384 |
+
self.inference_prefill_compiled = True
|
385 |
+
log.info("Compiled prefill function. Note: the first run will be slower due to compilation")
|
386 |
+
|
387 |
+
if not hasattr(self, "decode_one_token"):
|
388 |
+
self.decode_one_token = decode_one_token
|
389 |
+
if not hasattr(self, "prefill"):
|
390 |
+
self.prefill = prefill
|
391 |
+
|
392 |
+
# Initialization and Assertions
|
393 |
+
if isinstance(self.model.params, list):
|
394 |
+
# During training, model.params is a list
|
395 |
+
log.debug(
|
396 |
+
f"Find self.model.params is a list, use self.config instead. Get max_batch_size={self.config.max_batch_size}, max_seq_len={self.config.max_seq_len}"
|
397 |
+
)
|
398 |
+
params = self.config
|
399 |
+
else:
|
400 |
+
params = self.model.params
|
401 |
+
if isinstance(prompt_tokens, list):
|
402 |
+
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cuda")
|
403 |
+
if prompt_tokens.ndim == 1:
|
404 |
+
prompt_tokens = prompt_tokens.view(1, -1)
|
405 |
+
else:
|
406 |
+
assert prompt_tokens.ndim == 2, f"prompt_tokens has shape {prompt_tokens.shape}"
|
407 |
+
batch_size, prompt_len = prompt_tokens.shape
|
408 |
+
total_len = min(params.max_seq_len, max_gen_len + prompt_len)
|
409 |
+
if max_gen_len + prompt_len > params.max_seq_len:
|
410 |
+
log.warning(
|
411 |
+
f"max_gen_len + prompt_len={max_gen_len + prompt_len} exceeds max_seq_len={params.max_seq_len}, truncate max_gen_len to {params.max_seq_len - prompt_len}"
|
412 |
+
)
|
413 |
+
max_gen_len = params.max_seq_len - prompt_len
|
414 |
+
|
415 |
+
if context_mask is not None:
|
416 |
+
context_mask = context_mask.to(dtype=torch.bool)
|
417 |
+
if context_mask.ndim == 2:
|
418 |
+
assert (
|
419 |
+
context_mask.shape[0] == batch_size
|
420 |
+
), f"batch_size mismatch: {context_mask.shape[0]} != {batch_size}"
|
421 |
+
# Unsqueeze it to make it of shape [batch_size, 1, 1, context_seq_len]
|
422 |
+
context_mask = context_mask.view(batch_size, 1, 1, -1)
|
423 |
+
|
424 |
+
if num_gen_seq > 1:
|
425 |
+
assert (
|
426 |
+
batch_size == 1
|
427 |
+
), f"num_gen_seq > 1 is only supported for a single prompt, got {len(prompt_tokens)} prompts"
|
428 |
+
log.debug(f"Generating {num_gen_seq} sequences with the same prompt")
|
429 |
+
assert (
|
430 |
+
num_gen_seq <= params.max_batch_size
|
431 |
+
), f"num_gen_seq={num_gen_seq} exceeds max_batch_size={params.max_batch_size}"
|
432 |
+
# repeat the prompt tokens for num_gen_seq times
|
433 |
+
prompt_tokens = prompt_tokens.repeat(num_gen_seq, 1)
|
434 |
+
assert prompt_tokens.shape == (
|
435 |
+
num_gen_seq,
|
436 |
+
prompt_len,
|
437 |
+
), f"prompt_tokens must be of shape (num_gen_seq, seq_len), got {prompt_tokens.shape}"
|
438 |
+
batch_size = len(prompt_tokens)
|
439 |
+
|
440 |
+
# create an empty tensor of the expected final shape and fill in the current tokens
|
441 |
+
empty = torch.empty(batch_size, total_len, dtype=prompt_tokens.dtype, device=prompt_tokens.device)
|
442 |
+
empty[:, :prompt_len] = prompt_tokens
|
443 |
+
seq = empty
|
444 |
+
input_pos = torch.arange(0, prompt_len, device="cuda")
|
445 |
+
|
446 |
+
if verbose:
|
447 |
+
prefill_start = time.time()
|
448 |
+
|
449 |
+
if images is not None:
|
450 |
+
images = images.to(device=prompt_tokens.device, dtype=torch.bfloat16)
|
451 |
+
prompt_token_embeddings = self.embed_vision_language_features(prompt_tokens, images)
|
452 |
+
else:
|
453 |
+
prompt_token_embeddings = None
|
454 |
+
|
455 |
+
if context is not None:
|
456 |
+
context = context.to(device=prompt_tokens.device, dtype=self.precision)
|
457 |
+
|
458 |
+
# Prefill stage
|
459 |
+
next_token = self.prefill(
|
460 |
+
self.model,
|
461 |
+
input_pos=input_pos,
|
462 |
+
tokens=prompt_tokens if prompt_token_embeddings is None else None,
|
463 |
+
token_embeddings=prompt_token_embeddings,
|
464 |
+
temperature=temperature,
|
465 |
+
top_k=top_k,
|
466 |
+
top_p=top_p,
|
467 |
+
context=context,
|
468 |
+
context_mask=context_mask,
|
469 |
+
)
|
470 |
+
if verbose:
|
471 |
+
prefill_time = time.time() - prefill_start
|
472 |
+
|
473 |
+
seq[:, [prompt_len]] = next_token.to(dtype=seq.dtype)
|
474 |
+
input_pos = torch.tensor([prompt_len], dtype=torch.long, device="cuda")
|
475 |
+
stop_tokens = self.tokenizer.stop_tokens if stop_tokens is None else stop_tokens
|
476 |
+
stop_tokens = torch.tensor(list(stop_tokens), dtype=torch.long, device="cuda")
|
477 |
+
|
478 |
+
if verbose:
|
479 |
+
decode_start = time.time()
|
480 |
+
# Decode stage
|
481 |
+
generated_tokens = decode_n_tokens(
|
482 |
+
self.model,
|
483 |
+
next_token.view(batch_size, -1),
|
484 |
+
input_pos,
|
485 |
+
max_gen_len - 1,
|
486 |
+
temperature=temperature,
|
487 |
+
top_k=top_k,
|
488 |
+
top_p=top_p,
|
489 |
+
stop_tokens=stop_tokens,
|
490 |
+
decode_one_token_function=self.decode_one_token,
|
491 |
+
context=context,
|
492 |
+
context_mask=context_mask,
|
493 |
+
)
|
494 |
+
gen_len = len(generated_tokens)
|
495 |
+
if verbose:
|
496 |
+
decode_time = time.time() - decode_start
|
497 |
+
prefill_throughput = prompt_len / prefill_time
|
498 |
+
decode_throughput = gen_len / decode_time
|
499 |
+
log.debug(f"[Prefill] Time: {prefill_time:.2f}s; Throughput: {prefill_throughput:.2f} tokens/s")
|
500 |
+
log.debug(f"[Decode] Time: {decode_time:.2f}s; Throughput: {decode_throughput:.2f} tokens/s")
|
501 |
+
|
502 |
+
generated_tokens = torch.cat(generated_tokens, dim=1)
|
503 |
+
|
504 |
+
log.debug(f"generated_tokens: {generated_tokens.shape}")
|
505 |
+
seq = seq[:, : prompt_len + 1 + gen_len]
|
506 |
+
seq[:, prompt_len + 1 :] = generated_tokens
|
507 |
+
if not echo:
|
508 |
+
seq = seq[:, prompt_len:]
|
509 |
+
|
510 |
+
torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value
|
511 |
+
|
512 |
+
return seq, None
|
513 |
+
|
514 |
+
def embed_vision_language_features(self, input_ids: torch.Tensor, images: torch.Tensor) -> torch.Tensor:
|
515 |
+
"""
|
516 |
+
Embed vision and language features into a combined representation.
|
517 |
+
|
518 |
+
Args:
|
519 |
+
input_ids (torch.Tensor): Input token IDs.
|
520 |
+
images (torch.Tensor): Input images.
|
521 |
+
|
522 |
+
Returns:
|
523 |
+
torch.Tensor: Combined vision-language features.
|
524 |
+
|
525 |
+
Raises:
|
526 |
+
AssertionError: If vision encoder or mm projector is not initialized,
|
527 |
+
or if dimensions mismatch.
|
528 |
+
"""
|
529 |
+
# Ensure vision encoder and mm projector are initialized
|
530 |
+
assert self.vision_encoder is not None
|
531 |
+
assert self.mm_projector is not None
|
532 |
+
|
533 |
+
# Get image token ID and validate it
|
534 |
+
image_token_id = self.vision_encoder.image_token_id
|
535 |
+
assert isinstance(image_token_id, int) and image_token_id >= 0, f"Invalid image_token_id: {image_token_id}"
|
536 |
+
|
537 |
+
# Identify text and image locations in the input
|
538 |
+
text_locations = input_ids != image_token_id
|
539 |
+
image_locations = input_ids == image_token_id
|
540 |
+
|
541 |
+
# Process text features
|
542 |
+
text_features = self.model.tok_embeddings(input_ids[text_locations])
|
543 |
+
|
544 |
+
# Process image features
|
545 |
+
images = images.to(device=text_features.device, dtype=text_features.dtype)
|
546 |
+
vit_outputs = self.vision_encoder(images)
|
547 |
+
image_features = self.mm_projector(vit_outputs)
|
548 |
+
|
549 |
+
# Get dimensions
|
550 |
+
B, seq_len = input_ids.shape
|
551 |
+
N_total = B * seq_len
|
552 |
+
N_txt, D_txt = text_features.shape
|
553 |
+
N_img, N_patch, D_img = image_features.shape
|
554 |
+
|
555 |
+
# Reshape image features
|
556 |
+
image_features = image_features.reshape(N_img * N_patch, D_img)
|
557 |
+
|
558 |
+
# Validate dimensions
|
559 |
+
assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}"
|
560 |
+
assert (
|
561 |
+
N_total == N_txt + N_img * N_patch
|
562 |
+
), f"N_total {N_total} should equal N_txt + N_img * N_patch; got (N_txt, N_img * N_patch, num_image_tokens) = {(N_txt, N_img * N_patch, image_locations.sum().item())}"
|
563 |
+
|
564 |
+
# Combine text and image features
|
565 |
+
combined_features = torch.empty(
|
566 |
+
(B, seq_len, D_txt),
|
567 |
+
dtype=text_features.dtype,
|
568 |
+
device=text_features.device,
|
569 |
+
)
|
570 |
+
combined_features[text_locations, :] = text_features
|
571 |
+
combined_features[image_locations, :] = image_features
|
572 |
+
|
573 |
+
return combined_features
|
574 |
+
|
575 |
+
def state_dict(self, *args, **kwargs):
|
576 |
+
"""
|
577 |
+
Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8).
|
578 |
+
"""
|
579 |
+
state_dict = super().state_dict(*args, **kwargs)
|
580 |
+
return process_state_dict(state_dict)
|
581 |
+
|
582 |
+
def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False):
|
583 |
+
"""
|
584 |
+
Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by
|
585 |
+
TransformerEngine for FP8).
|
586 |
+
"""
|
587 |
+
state_dict = process_state_dict(state_dict)
|
588 |
+
missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign)
|
589 |
+
actual_missing_keys = []
|
590 |
+
for key in missing_keys:
|
591 |
+
if not any(substring in key for substring in substrings_to_ignore):
|
592 |
+
actual_missing_keys.append(key)
|
593 |
+
if strict:
|
594 |
+
if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0:
|
595 |
+
raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}")
|
596 |
+
return _IncompatibleKeys(actual_missing_keys, unexpected_keys)
|
ar_modules_attention.py
ADDED
@@ -0,0 +1,262 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import math
|
17 |
+
from typing import Optional, Union
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch import nn
|
21 |
+
|
22 |
+
from .ar_modules_embedding import RotaryPositionEmbedding
|
23 |
+
from .ar_modules_normalization import create_norm
|
24 |
+
|
25 |
+
|
26 |
+
class Attention(nn.Module):
|
27 |
+
"""
|
28 |
+
Attention layer with KV cache.
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
n_heads: int,
|
34 |
+
n_kv_heads: Union[int, None],
|
35 |
+
dim: int,
|
36 |
+
max_batch_size: int,
|
37 |
+
max_seq_len: int,
|
38 |
+
context_dim: Optional[int] = None,
|
39 |
+
use_qk_normalization: bool = False,
|
40 |
+
norm_type: str = "rmsnorm",
|
41 |
+
norm_eps: float = 1e-5,
|
42 |
+
causal_mask: Optional[bool] = True,
|
43 |
+
head_dim: Optional[int] = None,
|
44 |
+
fuse_qkv: bool = False,
|
45 |
+
precision: str = "bfloat16",
|
46 |
+
attn_type: str = "self",
|
47 |
+
):
|
48 |
+
"""
|
49 |
+
Initializes the GQA module.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
n_heads (int): The number of attention heads.
|
53 |
+
n_kv_heads (int, optional): The number of key-value attention heads. If None, defaults to n_heads.
|
54 |
+
dim (int): The dimensionality of the input and output.
|
55 |
+
max_batch_size (int): The maximum batch size.
|
56 |
+
max_seq_len (int): The maximum sequence length.
|
57 |
+
context_dim (int, optional): The dimensionality of the context for cross-attn. Defaults to None.
|
58 |
+
use_qk_normalization (bool, optional): Whether to apply QK normalization. Defaults to False.
|
59 |
+
norm_type (str, optional): The type of normalization layer. Defaults to "rmsnorm".
|
60 |
+
norm_eps (float, optional): The epsilon value for normalization. Defaults to 1e-5.
|
61 |
+
causal_mask (bool, optional): Whether to use causal mask. Defaults to True.
|
62 |
+
head_dim (int, optional): The dimensionality of each attention head. If None, defaults to dim // n_heads.
|
63 |
+
fuse_qkv (bool, optional): Whether to fuse QKV. Defaults to False.
|
64 |
+
precision (str, optional): The precision of the module. Defaults to "bfloat16".
|
65 |
+
attn_type (str, optional): The type of attention. Defaults to "self".
|
66 |
+
"""
|
67 |
+
super().__init__()
|
68 |
+
assert attn_type in ["self", "cross", "full"], f"Invalid attention type: {attn_type}"
|
69 |
+
self.attn_type = attn_type
|
70 |
+
context_dim = dim if context_dim is None else context_dim
|
71 |
+
|
72 |
+
self.dim = dim
|
73 |
+
self.context_dim = context_dim
|
74 |
+
self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
|
75 |
+
self.n_local_kv_heads = self.n_kv_heads
|
76 |
+
self.n_local_heads = n_heads
|
77 |
+
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
78 |
+
self.head_dim = dim // n_heads if head_dim is None else head_dim
|
79 |
+
self.causal_mask = causal_mask
|
80 |
+
self.fuse_qkv = fuse_qkv
|
81 |
+
self.precision = precision
|
82 |
+
|
83 |
+
if fuse_qkv:
|
84 |
+
assert context_dim == dim, f"Fuse QKV requires context_dim ({context_dim}) to be equal to dim ({dim})"
|
85 |
+
self.total_local_head_dim = (self.n_local_heads + 2 * self.n_local_kv_heads) * self.head_dim
|
86 |
+
self.wqkv = nn.Linear(dim, self.total_local_head_dim, bias=False)
|
87 |
+
# Register hook to load fused QKV weights
|
88 |
+
self._register_load_state_dict_pre_hook(self.load_hook)
|
89 |
+
else:
|
90 |
+
self.wq = nn.Linear(dim, self.n_local_heads * self.head_dim, bias=False)
|
91 |
+
self.wk = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False)
|
92 |
+
self.wv = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False)
|
93 |
+
self.wo = nn.Linear(self.n_local_heads * self.head_dim, dim, bias=False)
|
94 |
+
|
95 |
+
self.max_batch_size = max_batch_size
|
96 |
+
self.max_seq_len = max_seq_len
|
97 |
+
|
98 |
+
if self.attn_type == "self":
|
99 |
+
# Cache for key and value tensors
|
100 |
+
self.init_kv_cache()
|
101 |
+
|
102 |
+
# QK normalization layers
|
103 |
+
if use_qk_normalization:
|
104 |
+
self.q_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps)
|
105 |
+
self.k_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps)
|
106 |
+
|
107 |
+
self.use_qk_normalization = use_qk_normalization
|
108 |
+
|
109 |
+
self.to(dtype=getattr(torch, self.precision))
|
110 |
+
|
111 |
+
def load_hook(self, state_dict, prefix, *args):
|
112 |
+
if prefix + "wq.weight" in state_dict:
|
113 |
+
wq = state_dict.pop(prefix + "wq.weight")
|
114 |
+
wk = state_dict.pop(prefix + "wk.weight")
|
115 |
+
wv = state_dict.pop(prefix + "wv.weight")
|
116 |
+
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
|
117 |
+
|
118 |
+
def init_kv_cache(self, dtype=None):
|
119 |
+
cache_shape = (self.max_batch_size, self.n_local_kv_heads, self.max_seq_len, self.head_dim)
|
120 |
+
if dtype is None:
|
121 |
+
dtype = getattr(torch, self.precision)
|
122 |
+
if self.attn_type == "self":
|
123 |
+
self.cache_k = torch.zeros(cache_shape, dtype=dtype).cuda()
|
124 |
+
self.cache_v = torch.zeros(cache_shape, dtype=dtype).cuda()
|
125 |
+
|
126 |
+
def forward(
|
127 |
+
self,
|
128 |
+
x: torch.Tensor,
|
129 |
+
rope: RotaryPositionEmbedding,
|
130 |
+
input_pos: torch.Tensor,
|
131 |
+
mask: Optional[torch.Tensor] = None,
|
132 |
+
context: Optional[torch.Tensor] = None,
|
133 |
+
):
|
134 |
+
"""
|
135 |
+
Forward pass of GQA.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
x: The input tensor of shape (batch_size, seq_len, dim).
|
139 |
+
rope: The rotary positional embedding module.
|
140 |
+
input_pos: The starting position of the current sequence.
|
141 |
+
mask: The attention mask tensor.
|
142 |
+
context: The context tensor of shape (batch_size, context_len, dim).
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
The output tensor after applying GQA.
|
146 |
+
"""
|
147 |
+
bsz, seqlen, _ = x.shape
|
148 |
+
|
149 |
+
# Use one single module to handle both self-attn and cross-attn
|
150 |
+
context = x if context is None else context
|
151 |
+
context_len = seqlen if context is None else context.shape[1]
|
152 |
+
|
153 |
+
if self.fuse_qkv:
|
154 |
+
q_size = self.n_local_heads * self.head_dim
|
155 |
+
kv_size = self.n_local_kv_heads * self.head_dim
|
156 |
+
xq, xk, xv = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
|
157 |
+
else:
|
158 |
+
# Compute query, key, and value projections
|
159 |
+
xq, xk, xv = self.wq(x), self.wk(context), self.wv(context)
|
160 |
+
|
161 |
+
# Reshape projections
|
162 |
+
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
163 |
+
xk = xk.view(bsz, context_len, self.n_local_kv_heads, self.head_dim)
|
164 |
+
xv = xv.view(bsz, context_len, self.n_local_kv_heads, self.head_dim)
|
165 |
+
|
166 |
+
# QK normalization
|
167 |
+
if self.use_qk_normalization:
|
168 |
+
xq = self.q_norm(xq)
|
169 |
+
xk = self.k_norm(xk)
|
170 |
+
|
171 |
+
# Apply rotary positional embeddings to queries and keys
|
172 |
+
# RoPE is applied only to self- and full-attention, not to cross-attention
|
173 |
+
if self.attn_type in ["self", "full"]:
|
174 |
+
xq, xk = rope(xq, xk, input_pos, seqlen)
|
175 |
+
|
176 |
+
xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
|
177 |
+
# xq: (bs, n_local_heads, seqlen, head_dim)
|
178 |
+
# xk: (bs, n_kv_heads, cache_len + context_len, head_dim)
|
179 |
+
# xv: (bs, n_kv_heads, cache_len + context_len, head_dim)
|
180 |
+
if self.attn_type == "self":
|
181 |
+
# Update cache with current key and value tensors
|
182 |
+
assert input_pos is not None
|
183 |
+
self.cache_k[:bsz, :, input_pos] = xk
|
184 |
+
self.cache_v[:bsz, :, input_pos] = xv
|
185 |
+
keys, values = (
|
186 |
+
self.cache_k[:bsz, :, :],
|
187 |
+
self.cache_v[:bsz, :, :],
|
188 |
+
)
|
189 |
+
else:
|
190 |
+
keys, values = xk, xv
|
191 |
+
|
192 |
+
# Repeat keys and values if necessary
|
193 |
+
keys = keys.repeat_interleave(self.n_rep, dim=1) # (bs, n_local_heads, cache_len + context_len, head_dim)
|
194 |
+
values = values.repeat_interleave(self.n_rep, dim=1) # (bs, n_local_heads, cache_len + context_len, head_dim)
|
195 |
+
|
196 |
+
# For self-attention, `is_causal` should be set to False when KV cache is pre-computed and used,
|
197 |
+
# since the masking is handled outside this attention module.
|
198 |
+
# For cross-attention, it's always full-attn without causal mask
|
199 |
+
is_causal = False
|
200 |
+
output = scaled_dot_product_attention(
|
201 |
+
xq,
|
202 |
+
keys,
|
203 |
+
values,
|
204 |
+
head_dim=self.head_dim,
|
205 |
+
mask=mask,
|
206 |
+
is_causal=is_causal,
|
207 |
+
dropout_p=0.0,
|
208 |
+
)
|
209 |
+
output = output.view(bsz, seqlen, -1)
|
210 |
+
output = self.wo(output)
|
211 |
+
return output
|
212 |
+
|
213 |
+
|
214 |
+
def scaled_dot_product_attention(
|
215 |
+
q: torch.Tensor,
|
216 |
+
k: torch.Tensor,
|
217 |
+
v: torch.Tensor,
|
218 |
+
head_dim: int,
|
219 |
+
mask: Optional[torch.Tensor] = None,
|
220 |
+
is_causal: Optional[bool] = None,
|
221 |
+
dropout_p: float = 0.0,
|
222 |
+
) -> torch.Tensor:
|
223 |
+
"""
|
224 |
+
A thin wrapper around PyTorch's native scaled_dot_product_attention (which can dispatch to its FlashAttention-2 backend).
|
225 |
+
|
226 |
+
If `is_causal` is given, then the causal attention mask is applied accordingly:
|
227 |
+
- If `is_causal` is True, the standard upper-left causal attention masking is applied.
|
228 |
+
- If `is_causal` is False, no attention mask is applied, unless an explicit mask tensor is
|
229 |
+
provided (i.e., `mask is not None`).
|
230 |
+
|
231 |
+
If `is_causal` is not given (i.e., `is_causal is None`), then the attention mask is applied
|
232 |
+
based on the provided mask tensor:
|
233 |
+
- If no explicit attention mask is given (i.e., `mask is None`), `is_causal` is set to True,
|
234 |
+
leading to the standard upper-left causal attention masking.
|
235 |
+
- If an attention mask is given (i.e., `mask is not None`), the provided mask is used,
|
236 |
+
and `is_causal` is set to False.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
q (torch.Tensor): Query tensor
|
240 |
+
k (torch.Tensor): Key tensor
|
241 |
+
v (torch.Tensor): Value tensor
|
242 |
+
head_dim (int): Dimension of each attention head
|
243 |
+
mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None.
|
244 |
+
is_causal (Optional[bool], optional): Whether to apply causal attention mask. Defaults to None.
|
245 |
+
dropout_p (float, optional): Dropout rate. Defaults to 0.0.
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
torch.Tensor: Output tensor after applying scaled dot-product attention
|
249 |
+
"""
|
250 |
+
scale = 1.0 / math.sqrt(head_dim)
|
251 |
+
if is_causal is None:
|
252 |
+
is_causal = mask is None
|
253 |
+
y = torch.nn.functional.scaled_dot_product_attention(
|
254 |
+
q,
|
255 |
+
k,
|
256 |
+
v,
|
257 |
+
attn_mask=mask,
|
258 |
+
dropout_p=dropout_p,
|
259 |
+
scale=scale,
|
260 |
+
is_causal=is_causal,
|
261 |
+
)
|
262 |
+
return y.transpose(1, 2).contiguous()
|
ar_modules_embedding.py
ADDED
@@ -0,0 +1,491 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import math
|
17 |
+
from typing import List, Optional, Tuple
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
from einops import rearrange, repeat
|
22 |
+
|
23 |
+
|
24 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
25 |
+
"""
|
26 |
+
embed_dim: output dimension for each position
|
27 |
+
pos: a list of positions to be encoded: size (M,)
|
28 |
+
out: (M, D)
|
29 |
+
"""
|
30 |
+
assert embed_dim % 2 == 0
|
31 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
32 |
+
omega /= embed_dim / 2.0
|
33 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
34 |
+
|
35 |
+
pos = pos.reshape(-1) # (M,)
|
36 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
37 |
+
|
38 |
+
emb_sin = np.sin(out) # (M, D/2)
|
39 |
+
emb_cos = np.cos(out) # (M, D/2)
|
40 |
+
|
41 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
42 |
+
return emb
|
43 |
+
|
44 |
+
|
45 |
+
def _rotate_half_te(x: torch.Tensor) -> torch.Tensor:
|
46 |
+
"""
|
47 |
+
change sign so the last dimension becomes [-odd, +even].
|
48 |
+
Adopted from TransformerEngine.
|
49 |
+
Source: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py
|
50 |
+
"""
|
51 |
+
x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
|
52 |
+
x1, x2 = x.unbind(dim=-2)
|
53 |
+
return torch.cat((-x2, x1), dim=-1)
|
54 |
+
|
55 |
+
|
56 |
+
def _apply_rotary_pos_emb_te(
|
57 |
+
t: torch.Tensor,
|
58 |
+
cos_freqs: torch.Tensor,
|
59 |
+
sin_freqs: torch.Tensor,
|
60 |
+
) -> torch.Tensor:
|
61 |
+
"""
|
62 |
+
Apply rotary positional embedding tensor to the input tensor.
|
63 |
+
Adopted from TransformerEngine.
|
64 |
+
Source: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py
|
65 |
+
|
66 |
+
Parameters
|
67 |
+
----------
|
68 |
+
t: torch.Tensor
|
69 |
+
Input tensor of shape `[b, s, h, d]`, on which
|
70 |
+
rotary positional embedding will be applied.
|
71 |
+
cos_freqs: torch.Tensor
|
72 |
+
Cosine component of rotary positional embedding tensor of shape `[s, 1, 1, d]` and dtype 'float',
|
73 |
+
sin_freqs: torch.Tensor
|
74 |
+
Sine component of rotary positional embedding tensor of shape `[s, 1, 1, d]` and dtype 'float',
|
75 |
+
"""
|
76 |
+
rot_dim = cos_freqs.shape[-1]
|
77 |
+
# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
|
78 |
+
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
|
79 |
+
# first part is cosine component
|
80 |
+
# second part is sine component, need to change signs with _rotate_half method
|
81 |
+
t = (t * cos_freqs) + (_rotate_half_te(t) * sin_freqs)
|
82 |
+
output = torch.cat((t, t_pass), dim=-1)
|
83 |
+
return output
|
84 |
+
|
85 |
+
|
86 |
+
class RotaryPositionEmbedding(torch.nn.Module):
|
87 |
+
"""
|
88 |
+
Rotary Position Embedding module as described in the paper:
|
89 |
+
https://arxiv.org/abs/2104.09864
|
90 |
+
|
91 |
+
This module implements rotary positional embeddings, which are used to
|
92 |
+
enhance the performance of transformer models.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
dim (int): Dimensionality of the input tensor.
|
96 |
+
max_position_embeddings (Optional[int]): Maximum position embeddings.
|
97 |
+
original_max_position_embeddings (Optional[int]): Original maximum position embeddings.
|
98 |
+
rope_theta (Optional[float]): Base for the frequency calculation.
|
99 |
+
apply_yarn (Optional[bool]): Whether to apply YaRN (Yet another Rotary).
|
100 |
+
scale (Optional[int]): Scaling factor for the frequency calculation.
|
101 |
+
extrapolation_factor (Optional[int]): Extrapolation factor for the frequency extension.
|
102 |
+
attn_factor (Optional[int]): Attention factor for the frequency calculation.
|
103 |
+
beta_fast (Optional[int]): Fast beta value for the YaRN frequency calculation.
|
104 |
+
beta_slow (Optional[int]): Slow beta value for the YaRN frequency calculation.
|
105 |
+
rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "2D", "3D".
|
106 |
+
latent_shape (Optional[List[int]]): Shape of the latent tensor for video or image inputs.
|
107 |
+
original_latent_shape (Optional[List[int]]): Original shape of the latent tensor for video or image inputs.
|
108 |
+
pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
|
109 |
+
"""
|
110 |
+
|
111 |
+
def __init__(
|
112 |
+
self,
|
113 |
+
dim: int,
|
114 |
+
max_position_embeddings: Optional[int] = None,
|
115 |
+
original_max_position_embeddings: Optional[int] = None,
|
116 |
+
rope_theta: Optional[float] = 10000.0,
|
117 |
+
apply_yarn: Optional[bool] = False,
|
118 |
+
scale: Optional[int] = None,
|
119 |
+
extrapolation_factor: Optional[int] = 1,
|
120 |
+
attn_factor: Optional[int] = 1,
|
121 |
+
beta_fast: Optional[int] = 32,
|
122 |
+
beta_slow: Optional[int] = 1,
|
123 |
+
rope_dim: Optional[str] = "1D",
|
124 |
+
latent_shape: Optional[List[int]] = None,
|
125 |
+
original_latent_shape: Optional[List[int]] = None,
|
126 |
+
pad_to_multiple_of: Optional[int] = None,
|
127 |
+
):
|
128 |
+
super().__init__()
|
129 |
+
|
130 |
+
self.dim = dim
|
131 |
+
self.max_position_embeddings = max_position_embeddings
|
132 |
+
self.original_max_position_embeddings = original_max_position_embeddings
|
133 |
+
self.rope_theta = rope_theta
|
134 |
+
self.apply_yarn = apply_yarn
|
135 |
+
self.scale = scale
|
136 |
+
self.extrapolation_factor = extrapolation_factor
|
137 |
+
self.attn_factor = attn_factor
|
138 |
+
self.beta_fast = beta_fast
|
139 |
+
self.beta_slow = beta_slow
|
140 |
+
self.mscale = 1.0
|
141 |
+
self.rope_dim = rope_dim
|
142 |
+
self.latent_shape = latent_shape
|
143 |
+
self.original_latent_shape = original_latent_shape
|
144 |
+
self.pad_to_multiple_of = pad_to_multiple_of
|
145 |
+
self.get_inv_freq(torch.cuda.current_device())
|
146 |
+
|
147 |
+
def get_mscale(self, scale: float = 1.0) -> float:
|
148 |
+
"""Get the magnitude scaling factor for YaRN."""
|
149 |
+
if scale <= 1:
|
150 |
+
return 1.0
|
151 |
+
return 0.1 * math.log(scale) + 1.0
|
152 |
+
|
153 |
+
def forward(self, seq_len: Optional[int] = None) -> torch.Tensor:
|
154 |
+
"""
|
155 |
+
Forward pass for the rotary position embedding.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
seq_len (Optional[int]): Length of the sequence.
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
torch.Tensor: The computed frequencies for positional embedding.
|
162 |
+
"""
|
163 |
+
|
164 |
+
if self.apply_yarn and seq_len > self.max_seq_len_cached:
|
165 |
+
self.max_seq_len_cached = seq_len
|
166 |
+
self.freqs = self.compute_freqs()
|
167 |
+
|
168 |
+
return self.freqs
|
169 |
+
|
170 |
+
def compute_freqs(
|
171 |
+
self,
|
172 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
173 |
+
"""Compute the spatial frequencies for the latent tensor."""
|
174 |
+
self.seq = torch.arange(self.max_seq_len_cached, dtype=torch.float).cuda()
|
175 |
+
if self.rope_dim == "1D":
|
176 |
+
emb = torch.einsum("i,j->ij", self.seq, self.inv_freq)
|
177 |
+
|
178 |
+
elif self.rope_dim == "2D":
|
179 |
+
H, W = self.latent_shape
|
180 |
+
half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq)
|
181 |
+
half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq)
|
182 |
+
emb = torch.cat(
|
183 |
+
[
|
184 |
+
repeat(half_emb_h, "h d -> h w d", w=W),
|
185 |
+
repeat(half_emb_w, "w d -> h w d", h=H),
|
186 |
+
]
|
187 |
+
* 2,
|
188 |
+
dim=-1,
|
189 |
+
)
|
190 |
+
emb = rearrange(emb, "h w d -> (h w) 1 1 d").float()
|
191 |
+
|
192 |
+
elif self.rope_dim == "3D":
|
193 |
+
T, H, W = self.latent_shape
|
194 |
+
half_emb_t = torch.outer(self.seq[:T], self.temporal_inv_freq)
|
195 |
+
half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq)
|
196 |
+
half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq)
|
197 |
+
emb = torch.cat(
|
198 |
+
[
|
199 |
+
repeat(half_emb_t, "t d -> t h w d", h=H, w=W),
|
200 |
+
repeat(half_emb_h, "h d -> t h w d", t=T, w=W),
|
201 |
+
repeat(half_emb_w, "w d -> t h w d", t=T, h=H),
|
202 |
+
]
|
203 |
+
* 2,
|
204 |
+
dim=-1,
|
205 |
+
)
|
206 |
+
emb = rearrange(emb, "t h w d -> (t h w) 1 1 d").float()
|
207 |
+
else:
|
208 |
+
raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
|
209 |
+
return emb
|
210 |
+
|
211 |
+
def get_scale_factors(self, inv_freq: torch.Tensor, original_seq_len: int) -> torch.Tensor:
|
212 |
+
"""Get the scale factors for YaRN."""
|
213 |
+
# Calculate the high and low frequency cutoffs for YaRN. Note: `beta_fast` and `beta_slow` are called
|
214 |
+
# `high_freq_factor` and `low_freq_factor` in the Llama 3.1 RoPE scaling code.
|
215 |
+
high_freq_cutoff = 2 * math.pi * self.beta_fast / original_seq_len
|
216 |
+
low_freq_cutoff = 2 * math.pi * self.beta_slow / original_seq_len
|
217 |
+
# Obtain a smooth mask that has a value of 0 for low frequencies and 1 for high frequencies, with linear
|
218 |
+
# interpolation in between.
|
219 |
+
smooth_mask = torch.clamp((inv_freq - low_freq_cutoff) / (high_freq_cutoff - low_freq_cutoff), min=0, max=1)
|
220 |
+
# For low frequencies, we scale the frequency by 1/self.scale. For high frequencies, we keep the frequency.
|
221 |
+
scale_factors = (1 - smooth_mask) / self.scale + smooth_mask
|
222 |
+
return scale_factors
|
223 |
+
|
224 |
+
def get_inv_freq(self, device: torch.device) -> None:
|
225 |
+
"""Get the inverse frequency."""
|
226 |
+
if self.rope_dim == "1D":
|
227 |
+
assert self.max_position_embeddings is not None, "Max position embeddings required."
|
228 |
+
inv_freq = 1.0 / (
|
229 |
+
self.rope_theta ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)
|
230 |
+
)
|
231 |
+
if self.apply_yarn:
|
232 |
+
assert self.original_max_position_embeddings is not None, "Original max position embeddings required."
|
233 |
+
assert self.beta_slow is not None, "Beta slow value required."
|
234 |
+
assert self.beta_fast is not None, "Beta fast value required."
|
235 |
+
|
236 |
+
scale_factors = self.get_scale_factors(inv_freq, self.original_max_position_embeddings)
|
237 |
+
# Apply the scaling factors to inv_freq.
|
238 |
+
inv_freq = inv_freq * scale_factors
|
239 |
+
# Set the magnitude scaling factor.
|
240 |
+
self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)
|
241 |
+
self.max_seq_len_cached = self.max_position_embeddings
|
242 |
+
self.inv_freq = inv_freq
|
243 |
+
|
244 |
+
elif self.rope_dim == "2D":
|
245 |
+
assert self.latent_shape is not None, "Latent shape required."
|
246 |
+
dim_h = self.dim // 2
|
247 |
+
spatial_inv_freq = 1.0 / (
|
248 |
+
self.rope_theta ** torch.arange(0, dim_h, 2, dtype=torch.float32, device=device) / dim_h
|
249 |
+
)
|
250 |
+
if self.apply_yarn:
|
251 |
+
                assert self.original_latent_shape is not None, "Original latent shape required."
                assert self.beta_slow is not None, "Beta slow value required."
                assert self.beta_fast is not None, "Beta fast value required."

                scale_factors = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[0])
                spatial_inv_freq = spatial_inv_freq * scale_factors
                self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)
            self.spatial_inv_freq = spatial_inv_freq
            self.max_seq_len_cached = max(self.latent_shape)

        elif self.rope_dim == "3D":
            assert self.latent_shape is not None, "Latent shape required."
            dim_h = self.dim // 6 * 2
            dim_t = self.dim - 2 * dim_h
            self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(device) / dim_h
            spatial_inv_freq = 1.0 / (self.rope_theta**self.dim_spatial_range)
            self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(device) / dim_t
            temporal_inv_freq = 1.0 / (self.rope_theta**self.dim_temporal_range)
            if self.apply_yarn:
                assert self.original_latent_shape is not None, "Original latent shape required."
                assert self.beta_slow is not None, "Beta slow value required."
                assert self.beta_fast is not None, "Beta fast value required."
                scale_factors_spatial = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[1])
                spatial_inv_freq = spatial_inv_freq * scale_factors_spatial
                scale_factors_temporal = self.get_scale_factors(temporal_inv_freq, self.original_latent_shape[0])
                temporal_inv_freq = temporal_inv_freq * scale_factors_temporal
                self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)
            self.spatial_inv_freq = spatial_inv_freq
            self.temporal_inv_freq = temporal_inv_freq
            self.max_seq_len_cached = max(self.latent_shape)
        else:
            raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")

        self.freqs = self.compute_freqs()


class RotaryPositionEmbeddingPytorchV2(RotaryPositionEmbedding):
    """
    Rotary Position Embedding that works in the same way as the TransformerEngine RoPE
    (https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py)

    """

    def __init__(
        self,
        seq_len: int,
        training_type: str = None,
        **kwargs,
    ):
        super().__init__(
            **kwargs,
        )
        emb = self.create_rope_freqs(seq_len=seq_len, training_type=training_type)
        emb = emb.transpose(0, 1).contiguous()  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
        assert emb.shape[0] == 1 and emb.shape[2] == 1, f"emb shape: {emb.shape}"
        # cos/sin first then dtype conversion for better precision
        self.register_buffer("cos_cached", torch.cos(emb), persistent=False)
        self.register_buffer("sin_cached", torch.sin(emb), persistent=False)

    def create_rope_freqs(self, seq_len: int, training_type: str = None) -> torch.Tensor:
        """
        Create rotary position embedding frequencies.

        Args:
            seq_len (int): Sequence length of a sample.

        Returns:
            torch.Tensor: The computed positional embeddings.
        """
        if self.rope_dim == "1D":
            freqs = super().forward(seq_len=seq_len)
            emb = torch.cat((freqs, freqs), dim=-1)
            emb = emb.reshape(emb.size(0), 1, 1, emb.size(1))

        elif self.rope_dim in ["2D", "3D"]:
            emb = super().forward(seq_len=seq_len)
            if training_type == "text_to_video":
                # since we added <bov> token at the beginning of the video for text2world, we also extend the position embedding by one token in the beginning
                bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device)
                emb = torch.cat((bov_pe, emb), dim=0)
        else:
            raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
        if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0:
            # Round up to the nearest multiple of pad_to_multiple_of
            pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of
            emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device)), dim=0)

        return emb

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if q.dtype != self.cos_cached.dtype:
            self.cos_cached = self.cos_cached.to(q.dtype)
            self.sin_cached = self.sin_cached.to(q.dtype)

        cos_emb = self.cos_cached
        sin_emb = self.sin_cached
        if input_pos is not None:
            cos_emb = cos_emb[:, input_pos, :, :]
            sin_emb = sin_emb[:, input_pos, :, :]
        elif seq_len is not None:
            cos_emb = cos_emb[:, :seq_len, :, :]
            sin_emb = sin_emb[:, :seq_len, :, :]
        q = _apply_rotary_pos_emb_te(q, cos_emb, sin_emb)
        k = _apply_rotary_pos_emb_te(k, cos_emb, sin_emb)
        return q, k


class RotaryPositionEmbeddingPytorchV1(RotaryPositionEmbedding):
    """
    Rotary Position Embedding that works in the same way as
    mistral_inference (https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/rope.py)
    or llama3 (https://github.com/meta-llama/llama3/blob/main/llama/model.py)

    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(
            **kwargs,
        )
        if self.rope_dim == "1D":
            emb = torch.stack((self.freqs, self.freqs), dim=-1).reshape(*self.freqs.shape[:-1], -1)
        elif self.rope_dim in ["2D", "3D"]:
            emb = rearrange(self.freqs, "s 1 1 d -> s d").float()
        self.register_buffer("cos_cached", (emb.cos() * self.mscale)[None, :, None, :], persistent=False)
        self.register_buffer("sin_cached", (emb.sin() * self.mscale)[None, :, None, :], persistent=False)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate half the hidden dimensions of the input tensor."""
        x_reshaped = x.reshape(*x.shape[:-1], -1, 2)
        x1 = x_reshaped[..., 0]
        x2 = x_reshaped[..., 1]
        output = torch.stack((-x2, x1), dim=-1).reshape(*x.shape)
        return output

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for the rotary position embedding.

        Args:
            q (torch.Tensor): Query tensor.
            k (torch.Tensor): Key tensor.
            input_pos (Optional[torch.Tensor]): Starting position for the sequence.
            seq_len (Optional[int]): Length of the sequence.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors.
        """
        if self.apply_yarn and seq_len > self.max_seq_len_cached:
            freqs = super().forward(seq_len)
            if self.rope_dim == "1D":
                emb = torch.stack((freqs, freqs), dim=-1).reshape(*freqs.shape[:-1], -1)
            elif self.rope_dim in ["2D", "3D"]:
                emb = rearrange(freqs, "s 1 1 d -> s d").float()
            else:
                raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
            self.register_buffer(
                "cos_cached", (emb.cos() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False
            )
            self.register_buffer(
                "sin_cached", (emb.sin() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False
            )

        if input_pos is not None:
            cos_cached = self.cos_cached[:, input_pos]
            sin_cached = self.sin_cached[:, input_pos]
        else:
            assert (
                self.cos_cached.shape[1] >= seq_len
            ), f"Invalid sequence length; cos_cached.shape {self.cos_cached.shape}, seq_len {seq_len}."
            cos_cached = self.cos_cached[:, :seq_len, ...]
            sin_cached = self.sin_cached[:, :seq_len, ...]
        xq = q * cos_cached + self.rotate_half(q) * sin_cached
        xk = k * cos_cached + self.rotate_half(k) * sin_cached

        return xq.type_as(q), xk.type_as(k)


class SinCosPosEmbAxisTE(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        latent_shape: Optional[List[int]] = None,
        pad_to_multiple_of: Optional[int] = None,
        dtype: torch.dtype = torch.bfloat16,
        **kwargs,
    ):
        """
        Args:
            dim (int): Dimensionality of the input tensor.
            latent_shape (Optional[List[int]]): Shape of the latent tensor for video or image inputs.
            pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
            dtype (torch.dtype): Data type of the position embedding tensor.
        """
        super().__init__()
        dim_h = dim // 6 * 2
        dim_w = dim_h
        dim_t = dim - 2 * dim_h
        assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
        self.latent_shape = latent_shape
        T, H, W = latent_shape
        emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(H))
        emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(W))
        emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(T))

        self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).to(dtype=dtype, device="cuda"), persistent=False)
        self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).to(dtype=dtype, device="cuda"), persistent=False)
        self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).to(dtype=dtype, device="cuda"), persistent=False)
        self.pad_to_multiple_of = pad_to_multiple_of

    def forward(
        self,
        training_type: str = None,
    ) -> torch.Tensor:
        T, H, W = self.latent_shape
        emb = torch.cat(
            [
                repeat(self.pos_emb_t, "t d-> t h w d", h=H, w=W),
                repeat(self.pos_emb_h, "h d-> t h w d", t=T, w=W),
                repeat(self.pos_emb_w, "w d-> t h w d", t=T, h=H),
            ],
            dim=-1,
        )
        # Flatten the T,H,W dimensions
        emb = rearrange(emb, "t h w d -> (t h w) d")

        if training_type == "text_to_video":
            bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device, dtype=emb.dtype)
            emb = torch.cat((bov_pe, emb), dim=0)
        if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0:
            pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of
            emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device, dtype=emb.dtype)), dim=0)
        seq_len, dim = emb.shape
        emb = emb.reshape(1, seq_len, dim)
        return emb
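Editor's note: a minimal, self-contained sketch of the rotate-half RoPE math used by RotaryPositionEmbeddingPytorchV1 above, restricted to the 1D case. The helper names and the toy shapes are illustrative only and are not part of the diff.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Pair up the last dimension as (x0, x1), (x2, x3), ... and map each pair to (-x1, x0), (-x3, x2), ...
    x_pairs = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = x_pairs[..., 0], x_pairs[..., 1]
    return torch.stack((-x2, x1), dim=-1).reshape(*x.shape)

def apply_rope_1d(x: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    # x: [seq, dim]; build per-position angles (the inv_freq construction mirrors the 1/theta**(2i/d) used above)
    seq, dim = x.shape
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))   # [dim/2]
    angles = torch.outer(torch.arange(seq).float(), inv_freq)             # [seq, dim/2]
    emb = torch.repeat_interleave(angles, 2, dim=-1)                      # duplicate each angle for its pair -> [seq, dim]
    return x * emb.cos() + rotate_half(x) * emb.sin()

x = torch.randn(8, 16)
y = apply_rope_1d(x)
# Position 0 has angle 0, so its vector is unchanged; the per-pair rotation preserves norms everywhere.
assert torch.allclose(y[0], x[0], atol=1e-6)
assert torch.allclose(y.norm(dim=-1), x.norm(dim=-1), atol=1e-5)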
ar_modules_mlp.py
ADDED
@@ -0,0 +1,50 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        """
        Initializes the multilayer perceptron (MLP) module.

        Args:
            dim: The input and output dimensionality.
            hidden_dim: The dimensionality of the hidden layer.
        """
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass of the MLP module.

        Args:
            x: The input tensor of shape (batch_size, dim).

        Returns:
            The output tensor of shape (batch_size, dim).
        """
        output = self.w2(F.silu(self.w1(x)) * self.w3(x))
        return output
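Editor's note: a short usage sketch for the SwiGLU-style MLP above (gate w1 -> SiLU, value branch w3, output projection w2). The dimensions are illustrative; nn.Linear applies over the last axis, so a (batch, seq, dim) input also works.

import torch

mlp = MLP(dim=512, hidden_dim=1376)   # hidden_dim chosen arbitrarily for illustration
x = torch.randn(2, 64, 512)           # (batch, seq, dim)
y = mlp(x)
assert y.shape == x.shape             # shape-preserving, so it can sit inside a residual path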
ar_modules_normalization.py
ADDED
@@ -0,0 +1,88 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn


def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
    """
    Creates the specified normalization layer based on the norm_type.
    Adapted from TorchTitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py

    Args:
        norm_type (str): The type of normalization layer to create.
            Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm
        dim (int): The dimension of the normalization layer.
        eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.

    Returns:
        The created normalization layer.

    Raises:
        NotImplementedError: If an unknown norm_type is provided.
    """
    norm_type = norm_type.lower()  # Normalize to lowercase

    if norm_type == "layernorm":
        return nn.LayerNorm(dim, eps=eps, bias=False)
    elif norm_type == "np_layernorm":
        return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
    elif norm_type == "rmsnorm":
        return RMSNorm(dim, eps=eps, compile=False)
    elif norm_type == "compiled_rmsnorm":
        return RMSNorm(dim, eps=eps, compile=True)
    elif norm_type == "fused_rmsnorm":
        raise NotImplementedError("Fused RMSNorm is not supported yet.")
    else:
        raise NotImplementedError(f"Unknown norm_type: '{norm_type}'")


class RMSNorm(nn.Module):
    """
    Initialize the RMSNorm normalization layer.
    Reference implementation: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py

    Args:
        dim (int): The dimension of the input tensor.
        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
        compile (bool, optional): Whether to compile the forward function. Default is False.

    Attributes:
        eps (float): A small value added to the denominator for numerical stability.
        weight (nn.Parameter): Learnable scaling parameter.

    """

    def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self.rmsnorm_fn = torch.compile(self.compute_rmsnorm, fullgraph=True) if compile else self.compute_rmsnorm

    @staticmethod
    def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float):
        def _norm(x, eps):
            # Computes the root-mean-square norm of the input tensor.
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

        output = _norm(x.float(), eps).type_as(x)
        return output * weight

    def forward(self, x: torch.Tensor):
        return self.rmsnorm_fn(x, self.weight, self.eps)

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)
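Editor's note: a small numerical check of what RMSNorm above computes, i.e. y = x / sqrt(mean(x^2) + eps) * weight, compared against the class directly. Shapes are illustrative.

import torch

norm = RMSNorm(dim=8, eps=1e-6)
x = torch.randn(4, 8)
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight
assert torch.allclose(norm(x), manual, atol=1e-6)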
ar_networks.py
ADDED
@@ -0,0 +1,63 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple

import torch
from torch import nn

from .ar_tokenizer_modules import CausalConv3d, DecoderFactorized, EncoderFactorized
from .ar_tokenizer_quantizers import FSQuantizer
from .log import log

NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"])


class CausalDiscreteVideoTokenizer(nn.Module):
    def __init__(self, z_channels: int, z_factor: int, embedding_dim: int, **kwargs) -> None:
        super().__init__()
        self.name = kwargs.get("name", "CausalDiscreteVideoTokenizer")
        self.embedding_dim = embedding_dim
        self.encoder = EncoderFactorized(z_channels=z_factor * z_channels, **kwargs)
        self.decoder = DecoderFactorized(z_channels=z_channels, **kwargs)

        self.quant_conv = CausalConv3d(z_factor * z_channels, embedding_dim, kernel_size=1, padding=0)
        self.post_quant_conv = CausalConv3d(embedding_dim, z_channels, kernel_size=1, padding=0)

        self.quantizer = FSQuantizer(**kwargs)

        num_parameters = sum(param.numel() for param in self.parameters())
        log.debug(f"model={self.name}, num_parameters={num_parameters:,}")
        log.debug(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.")

    def to(self, *args, **kwargs):
        setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16))
        return super(CausalDiscreteVideoTokenizer, self).to(*args, **kwargs)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return self.quantizer(h)

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        return self.decoder(quant)

    def forward(self, input):
        quant_info, quant_codes, quant_loss = self.encode(input)
        reconstructions = self.decode(quant_codes)
        if self.training:
            return dict(reconstructions=reconstructions, quant_loss=quant_loss, quant_info=quant_info)
        return NetworkEval(reconstructions=reconstructions, quant_loss=quant_loss, quant_info=quant_info)
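Editor's note: the tokenizer's forward above returns a dict in training mode and a NetworkEval namedtuple in eval mode. A minimal sketch of consuming both, using a stand-in output; the real constructor kwargs (channel sizes, FSQ levels, ...) live elsewhere in this repo and are not reproduced here.

import torch
from collections import namedtuple

NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"])

def unpack(output):
    # Works for both the training-mode dict and the eval-mode namedtuple.
    if isinstance(output, dict):
        return output["reconstructions"], output["quant_loss"], output["quant_info"]
    return output.reconstructions, output.quant_loss, output.quant_info

fake = NetworkEval(reconstructions=torch.zeros(1, 3, 9, 64, 64), quant_loss=torch.tensor(0.0), quant_info=None)
recon, loss, info = unpack(fake)
assert recon.shape[1] == 3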
ar_tokenizer.py
ADDED
@@ -0,0 +1,322 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from collections import defaultdict
|
17 |
+
from typing import Optional
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from einops import rearrange
|
21 |
+
|
22 |
+
from .ar_config_tokenizer import TokenizerConfig
|
23 |
+
from .lazy_config_init import instantiate as lazy_instantiate
|
24 |
+
|
25 |
+
|
26 |
+
def update_vocab_size(
|
27 |
+
existing_vocab_size,
|
28 |
+
to_be_added_vocab_size,
|
29 |
+
training_type,
|
30 |
+
add_special_tokens,
|
31 |
+
video_special_tokens={},
|
32 |
+
):
|
33 |
+
# New vocab size
|
34 |
+
if add_special_tokens:
|
35 |
+
existing_vocab_size += to_be_added_vocab_size + len(video_special_tokens)
|
36 |
+
# For text_to_video, we add one <bov> special token at the beginning of the video
|
37 |
+
elif training_type == "text_to_video":
|
38 |
+
existing_vocab_size += to_be_added_vocab_size + 1
|
39 |
+
else:
|
40 |
+
existing_vocab_size += to_be_added_vocab_size
|
41 |
+
return existing_vocab_size
|
42 |
+
|
43 |
+
|
44 |
+
class DiscreteMultimodalTokenizer:
|
45 |
+
def __init__(self, tokenizer_config: TokenizerConfig):
|
46 |
+
self.tokenizer_config = tokenizer_config
|
47 |
+
self.vocab_size = 0
|
48 |
+
self.total_seq_len = tokenizer_config.seq_len
|
49 |
+
self.pad_to_multiple_of = tokenizer_config.pad_to_multiple_of
|
50 |
+
self.training_type = tokenizer_config.training_type
|
51 |
+
assert self.training_type in [
|
52 |
+
"text_only",
|
53 |
+
"text_to_video",
|
54 |
+
"video_to_video",
|
55 |
+
"image_text_interleaved",
|
56 |
+
], f"{self.training_type} not supported"
|
57 |
+
|
58 |
+
self._build_text_tokenizer()
|
59 |
+
self._build_video_tokenizer()
|
60 |
+
|
61 |
+
def _build_text_tokenizer(self):
|
62 |
+
r"""Function to initialize the text tokenizer model."""
|
63 |
+
if self.tokenizer_config.text_tokenizer is not None:
|
64 |
+
self.text_tokenizer = lazy_instantiate(self.tokenizer_config.text_tokenizer.config)
|
65 |
+
self.vocab_size += self.tokenizer_config.text_tokenizer.vocab_size
|
66 |
+
else:
|
67 |
+
self.text_tokenizer = None
|
68 |
+
|
69 |
+
def _build_video_tokenizer(self):
|
70 |
+
r"""Function to initialize the video tokenizer model."""
|
71 |
+
if self.tokenizer_config.video_tokenizer is not None:
|
72 |
+
self.video_tokenizer = lazy_instantiate(self.tokenizer_config.video_tokenizer.config)
|
73 |
+
self.video_tokenizer = self.video_tokenizer.to("cuda")
|
74 |
+
self.video_vocab_size = self.tokenizer_config.video_tokenizer.vocab_size
|
75 |
+
special_token_offset = (
|
76 |
+
self.tokenizer_config.video_tokenizer.tokenizer_offset
|
77 |
+
+ self.tokenizer_config.video_tokenizer.vocab_size
|
78 |
+
)
|
79 |
+
self.video_special_tokens = {
|
80 |
+
"<|begin_of_video|>": special_token_offset,
|
81 |
+
"<|end_of_video|>": special_token_offset + 1,
|
82 |
+
"<|pad_token_video|>": special_token_offset + 2,
|
83 |
+
}
|
84 |
+
|
85 |
+
self.vocab_size = update_vocab_size(
|
86 |
+
existing_vocab_size=self.vocab_size,
|
87 |
+
to_be_added_vocab_size=self.tokenizer_config.video_tokenizer.vocab_size,
|
88 |
+
training_type=self.training_type,
|
89 |
+
add_special_tokens=self.tokenizer_config.add_special_tokens,
|
90 |
+
video_special_tokens=self.video_special_tokens,
|
91 |
+
)
|
92 |
+
else:
|
93 |
+
self.video_tokenizer = None
|
94 |
+
|
95 |
+
@property
|
96 |
+
def pad_id(self):
|
97 |
+
r"""Returns the pad_id."""
|
98 |
+
|
99 |
+
if self.training_type == "text_only" or self.training_type == "image_text_interleaved":
|
100 |
+
pad_id = self.text_tokenizer.pad_id
|
101 |
+
elif self.training_type in ["text_to_video", "video_to_video"]:
|
102 |
+
pad_id = self.video_special_tokens["<|pad_token_video|>"]
|
103 |
+
else:
|
104 |
+
raise ValueError(f"training_type {self.training_type} not defined")
|
105 |
+
return pad_id
|
106 |
+
|
107 |
+
@property
|
108 |
+
def ignore_index(self):
|
109 |
+
r"""Returns which token should be ignored during loss computation."""
|
110 |
+
if self.training_type == "text_only" or self.training_type == "image_text_interleaved":
|
111 |
+
if self.text_tokenizer.pad_id == self.text_tokenizer.eos_id:
|
112 |
+
# If the PAD token is the same as the EOS token, we do not ignore it during loss
|
113 |
+
# computation, since we want the model to be able to predict EOS tokens in inference.
|
114 |
+
# The PyTorch default ignore_index for the cross-entropy loss is -100.
|
115 |
+
ignore_index = -100
|
116 |
+
else:
|
117 |
+
ignore_index = self.text_tokenizer.pad_id
|
118 |
+
elif self.training_type in ["text_to_video", "video_to_video"]:
|
119 |
+
ignore_index = self.pad_id
|
120 |
+
else:
|
121 |
+
raise ValueError(f"training_type {self.training_type} not defined")
|
122 |
+
return ignore_index
|
123 |
+
|
124 |
+
@property
|
125 |
+
def stop_tokens(self):
|
126 |
+
r"""Returns the stop tokens."""
|
127 |
+
if self.training_type == "text_only" or self.training_type == "image_text_interleaved":
|
128 |
+
stop_tokens = self.text_tokenizer.stop_tokens
|
129 |
+
elif self.training_type in ["text_to_video", "video_to_video"]:
|
130 |
+
stop_tokens = set([self.video_special_tokens["<|end_of_video|>"]])
|
131 |
+
else:
|
132 |
+
raise ValueError(f"training_type {self.training_type} not defined")
|
133 |
+
return stop_tokens
|
134 |
+
|
135 |
+
def _tokenize_text(self, raw_text: list[str], max_text_seq_len: int = -1):
|
136 |
+
r"""Function to tokenize text.
|
137 |
+
Args:
|
138 |
+
raw_text (list[str]): List of input strings
|
139 |
+
max_text_seq_len (int): Maximum sequence length returned by text tokenizer
|
140 |
+
Returns:
|
141 |
+
text_tokens (list[list[int]]): List of text tokens
|
142 |
+
"""
|
143 |
+
|
144 |
+
batch_size = len(raw_text)
|
145 |
+
text_tokens = [self.text_tokenizer.encode(raw_text[i], bos=True, eos=True) for i in range(batch_size)]
|
146 |
+
|
147 |
+
# Clipping the text tokens so that the sequence length does not exceed max_text_seq_len
|
148 |
+
if max_text_seq_len > -1:
|
149 |
+
for i in range(len(text_tokens)):
|
150 |
+
if len(text_tokens[i]) > max_text_seq_len:
|
151 |
+
# Simply clip and add end of seq token
|
152 |
+
text_tokens[i] = text_tokens[i][0 : max_text_seq_len - 1] + [self.text_tokenizer.eos_id]
|
153 |
+
return text_tokens
|
154 |
+
|
155 |
+
def _tokenize_class(self, cls_labels: list[str]):
|
156 |
+
r"""Function to tokenize the class label.
|
157 |
+
Args:
|
158 |
+
cls_labels (list[str]): List of class indices
|
159 |
+
Returns:
|
160 |
+
class_tokens (list[list[int]]): List of class tokens
|
161 |
+
"""
|
162 |
+
|
163 |
+
# tokenizer_offset tells what offset should be added to the tokens.
|
164 |
+
# This is needed for vocab expansion.
|
165 |
+
class_tokens = [[int(x) + self.tokenizer_config.class_tokenizer.tokenizer_offset] for x in cls_labels]
|
166 |
+
|
167 |
+
return class_tokens
|
168 |
+
|
169 |
+
def _tokenize_video(self, videos: torch.Tensor, pixel_chunk_duration: Optional[int] = None):
|
170 |
+
r"""Function to tokenize video.
|
171 |
+
Args:
|
172 |
+
videos (torch.Tensor): Input video data tensor
|
173 |
+
pixel_chunk_duration (Optional[float]): Pixel chunk duration. If provided, we pass it to the video tokenizer.
|
174 |
+
Returns:
|
175 |
+
video_tokens (list[list[int]]): List of video tokens
|
176 |
+
"""
|
177 |
+
|
178 |
+
video_tokens = []
|
179 |
+
batch_size = videos.shape[0]
|
180 |
+
|
181 |
+
quantized_out, _ = self.video_tokenizer.encode(videos, pixel_chunk_duration=pixel_chunk_duration)
|
182 |
+
indices = self.video_tokenizer.fsq_quantizer.codes_to_indices(quantized_out.permute(0, 2, 3, 4, 1))
|
183 |
+
|
184 |
+
# Flatten the indices
|
185 |
+
indices = rearrange(indices, "B T H W -> B (T H W)")
|
186 |
+
|
187 |
+
# tokenizer_offset tells what offset should be added to the tokens.
|
188 |
+
# This is needed for vocab expansion.
|
189 |
+
indices += self.tokenizer_config.video_tokenizer.tokenizer_offset
|
190 |
+
|
191 |
+
# Add begin and end of video tokens
|
192 |
+
bov_token = self.video_special_tokens["<|begin_of_video|>"]
|
193 |
+
eov_token = self.video_special_tokens["<|end_of_video|>"]
|
194 |
+
|
195 |
+
# Append bov and eov tokens
|
196 |
+
if self.tokenizer_config.add_special_tokens:
|
197 |
+
for i in range(batch_size):
|
198 |
+
video_tokens.append([bov_token] + indices[i].tolist() + [eov_token])
|
199 |
+
else:
|
200 |
+
if self.training_type == "text_to_video":
|
201 |
+
for i in range(batch_size):
|
202 |
+
video_tokens.append([bov_token] + indices[i].tolist())
|
203 |
+
else:
|
204 |
+
for i in range(batch_size):
|
205 |
+
video_tokens.append(indices[i].tolist())
|
206 |
+
assert (
|
207 |
+
len(video_tokens[-1]) == self.tokenizer_config.video_tokenizer.max_seq_len
|
208 |
+
), f"Expected {self.tokenizer_config.video_tokenizer.max_seq_len} tokens, got {len(video_tokens[-1])}; video shape: {videos.shape}"
|
209 |
+
|
210 |
+
return video_tokens
|
211 |
+
|
212 |
+
def tokenize(self, data_batch: dict):
|
213 |
+
r"""Function to tokenize data_dict.
|
214 |
+
Args:
|
215 |
+
data_batch (dict): Input data dict
|
216 |
+
Returns:
|
217 |
+
tokens (torch.LongTensor): Token tensor dict
|
218 |
+
"""
|
219 |
+
|
220 |
+
if (
|
221 |
+
self.training_type in ["text_only", "image_text_interleaved"]
|
222 |
+
and not self.tokenizer_config.text_tokenizer.tokenize_here
|
223 |
+
):
|
224 |
+
# In case of pre-computed tokens, just return the data_batch
|
225 |
+
return data_batch["tokens"], None
|
226 |
+
|
227 |
+
# Online tokenization
|
228 |
+
tokens = []
|
229 |
+
token_boundaries = defaultdict(list)
|
230 |
+
|
231 |
+
# Obtain maximum sequence length
|
232 |
+
max_text_seq_len = -1
|
233 |
+
max_visual_seq_len = -1
|
234 |
+
|
235 |
+
if self.training_type in ["text_to_video", "video_to_video"]:
|
236 |
+
max_visual_seq_len = self.tokenizer_config.video_tokenizer.max_seq_len
|
237 |
+
|
238 |
+
# If max visual sequence length is specified, make sure that text is clipped so that
|
239 |
+
# the full video/image is always seen.
|
240 |
+
if max_visual_seq_len > -1:
|
241 |
+
if self.tokenizer_config.add_special_tokens:
|
242 |
+
max_visual_seq_len = max_visual_seq_len + 2 # Two special tokens is for [bov, eov] or [boi, eoi] token
|
243 |
+
elif self.training_type == "text_to_video":
|
244 |
+
max_visual_seq_len = max_visual_seq_len + 1
|
245 |
+
else:
|
246 |
+
max_visual_seq_len = max_visual_seq_len
|
247 |
+
assert (
|
248 |
+
max_visual_seq_len <= self.total_seq_len
|
249 |
+
), f"max_visual_seq_len ({max_visual_seq_len}) is greater that total sequence length ({self.total_seq_len})"
|
250 |
+
max_text_seq_len = self.total_seq_len - max_visual_seq_len
|
251 |
+
|
252 |
+
# Tokenize the text
|
253 |
+
if (
|
254 |
+
"text" in self.training_type
|
255 |
+
and self.text_tokenizer is not None
|
256 |
+
and self.tokenizer_config.text_tokenizer.tokenize_here
|
257 |
+
):
|
258 |
+
key = self.tokenizer_config.text_tokenizer.data_key
|
259 |
+
batch_size = len(data_batch[key])
|
260 |
+
assert key in data_batch, f"Key {key} should be present in data for text tokenizer"
|
261 |
+
tokens = self._tokenize_text(data_batch["caption"], max_text_seq_len)
|
262 |
+
|
263 |
+
for i in range(batch_size):
|
264 |
+
token_boundaries["text"].append((0, len(tokens[i])))
|
265 |
+
else:
|
266 |
+
tokens = []
|
267 |
+
batch_size = None
|
268 |
+
|
269 |
+
# Tokenize the class label
|
270 |
+
if "class" in self.training_type and self.tokenizer_config.class_tokenizer is not None:
|
271 |
+
key = self.tokenizer_config.class_tokenizer.data_key
|
272 |
+
assert key in data_batch, f"Key {key} should be present in data for class tokenizer"
|
273 |
+
batch_size = len(data_batch[key]) if batch_size is None else batch_size
|
274 |
+
tokens_class = self._tokenize_class(data_batch[key])
|
275 |
+
if len(tokens) == 0:
|
276 |
+
tokens = tokens_class
|
277 |
+
for i in range(batch_size):
|
278 |
+
token_boundaries["class"].append((0, len(tokens[i])))
|
279 |
+
else:
|
280 |
+
for i in range(batch_size):
|
281 |
+
token_boundaries["class"].append((len(tokens[i]), len(tokens[i]) + len(tokens_class[i])))
|
282 |
+
tokens[i] = tokens[i] + tokens_class[i]
|
283 |
+
|
284 |
+
# Tokenize the video
|
285 |
+
if self.video_tokenizer is not None and self.tokenizer_config.video_tokenizer.tokenize_here:
|
286 |
+
key = self.tokenizer_config.video_tokenizer.data_key
|
287 |
+
assert key in data_batch, f"Key {key} should be present in data for video tokenizer"
|
288 |
+
batch_size = len(data_batch[key]) if batch_size is None else batch_size
|
289 |
+
|
290 |
+
pixel_chunk_duration = (
|
291 |
+
None # If not specified, we assume it's a video dataset and use the default chunk duration
|
292 |
+
)
|
293 |
+
dataset_name = data_batch.get("dataset_name", None)
|
294 |
+
if dataset_name is not None and dataset_name.startswith("image"):
|
295 |
+
# If it's an image dataset, we use a pixel chunk duration of 1
|
296 |
+
pixel_chunk_duration = 1
|
297 |
+
tokens_video = self._tokenize_video(data_batch[key], pixel_chunk_duration=pixel_chunk_duration)
|
298 |
+
if len(tokens) == 0:
|
299 |
+
tokens = tokens_video
|
300 |
+
for i in range(batch_size):
|
301 |
+
token_boundaries["video"].append((0, len(tokens[i])))
|
302 |
+
# [B,] each entry is ((0, len(tokens[i])))
|
303 |
+
else:
|
304 |
+
for i in range(batch_size):
|
305 |
+
token_boundaries["video"].append((len(tokens[i]), len(tokens[i]) + len(tokens_video[i])))
|
306 |
+
tokens[i] = tokens[i] + tokens_video[i]
|
307 |
+
|
308 |
+
# Combine the tokens and do padding
|
309 |
+
max_seq_len_in_batch = max([len(token) for token in tokens])
|
310 |
+
if self.pad_to_multiple_of is not None:
|
311 |
+
# Pad the sequence length to the nearest multiple of pad_to_multiple_of
|
312 |
+
max_seq_len_in_batch = ((max_seq_len_in_batch - 1) // self.pad_to_multiple_of + 1) * self.pad_to_multiple_of
|
313 |
+
pad_to_len = min(max_seq_len_in_batch, self.total_seq_len)
|
314 |
+
for i in range(len(tokens)):
|
315 |
+
if len(tokens[i]) < pad_to_len:
|
316 |
+
tokens[i] = tokens[i] + [self.pad_id] * (pad_to_len - len(tokens[i]))
|
317 |
+
else:
|
318 |
+
tokens[i] = tokens[i][0:pad_to_len]
|
319 |
+
|
320 |
+
# Convert it to long tensor
|
321 |
+
tokens = torch.LongTensor(tokens)
|
322 |
+
return tokens, token_boundaries
|
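Editor's note: a worked example of the update_vocab_size helper defined near the top of ar_tokenizer.py. With add_special_tokens=True the vocabulary grows by the video codebook size plus the three video special tokens; for text_to_video without special tokens it grows by the codebook size plus one <bov> token. The concrete sizes below are illustrative, not values from this repo.

text_vocab = 131072          # illustrative text-tokenizer vocabulary size
video_vocab = 64000          # illustrative video codebook size
specials = {"<|begin_of_video|>": 0, "<|end_of_video|>": 1, "<|pad_token_video|>": 2}

with_specials = update_vocab_size(text_vocab, video_vocab, "text_to_video", True, specials)
assert with_specials == text_vocab + video_vocab + 3

bov_only = update_vocab_size(text_vocab, video_vocab, "text_to_video", False)
assert bov_only == text_vocab + video_vocab + 1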
ar_tokenizer_image_text_tokenizer.py
ADDED
@@ -0,0 +1,318 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Any, Dict, List, Optional, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
import transformers
|
21 |
+
from transformers import AutoImageProcessor
|
22 |
+
from transformers.image_utils import ImageInput, is_valid_image, load_image
|
23 |
+
|
24 |
+
from .ar_tokenizer_text_tokenizer import TextTokenizer
|
25 |
+
from .log import log
|
26 |
+
|
27 |
+
# Configuration for different vision-language models
|
28 |
+
IMAGE_CONFIGS = {
|
29 |
+
"pixtral": {
|
30 |
+
"patch_size": 16,
|
31 |
+
"image_token": "[IMG]",
|
32 |
+
"image_break_token": "[IMG_BREAK]",
|
33 |
+
"image_end_token": "[IMG_END]",
|
34 |
+
}
|
35 |
+
}
|
36 |
+
|
37 |
+
# Chat template for Pixtral-12B-Instruct
|
38 |
+
PIXTRAL_CHAT_TEMPLATE = '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n {%- endif %}\n {%- if message["role"] == "user" %}\n {%- if loop.last and system_message is defined %}\n {{- "[INST]" + system_message + "\n\n" }}\n {%- else %}\n {{- "[INST]" }}\n {%- endif %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["content"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message["content"] }}\n {%- endif %}\n {{- "[/INST]" }}\n {%- elif message["role"] == "assistant" %}\n {{- message["content"] + eos_token}}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}'
|
39 |
+
|
40 |
+
|
41 |
+
# Copied from transformers.models.pixtral.processing_pixtral.is_url
|
42 |
+
def is_url(val) -> bool:
|
43 |
+
"""Check if the given value is a URL."""
|
44 |
+
return isinstance(val, str) and val.startswith("http")
|
45 |
+
|
46 |
+
|
47 |
+
# Copied from transformers.models.pixtral.processing_pixtral.is_image_or_image_url
|
48 |
+
def is_image_or_image_url(elem):
|
49 |
+
"""Check if the given element is an image or an image URL."""
|
50 |
+
return is_url(elem) or is_valid_image(elem)
|
51 |
+
|
52 |
+
|
53 |
+
def load_image_list(
|
54 |
+
image_list: List[Union[str, "PIL.Image.Image"]], timeout: Optional[float] = None
|
55 |
+
) -> List["PIL.Image.Image"]:
|
56 |
+
"""
|
57 |
+
Load a list of images.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
image_list (List[Union[str, PIL.Image.Image]]): The list of images to load.
|
61 |
+
timeout (Optional[float]): The timeout for loading the image.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
List[PIL.Image.Image]: The list of loaded images.
|
65 |
+
"""
|
66 |
+
return [load_image(image, timeout=timeout) for image in image_list]
|
67 |
+
|
68 |
+
|
69 |
+
class ImageTextTokenizer(TextTokenizer):
|
70 |
+
"""
|
71 |
+
Image-text tokenizer class that extends the text tokenizer to support vision tokens as well.
|
72 |
+
"""
|
73 |
+
|
74 |
+
def __init__(
|
75 |
+
self,
|
76 |
+
model_family: str,
|
77 |
+
is_instruct_model: bool,
|
78 |
+
tokenizer_path: str,
|
79 |
+
image_processor_path: str,
|
80 |
+
):
|
81 |
+
"""
|
82 |
+
Initialize the ImageTextTokenizer.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
model_family (str): The model family.
|
86 |
+
is_instruct_model (bool): Whether the model is an instruct model.
|
87 |
+
s3_credential_path (str): The path to the s3 credential file. Defaults to "credentials/pbss_dir.secret".
|
88 |
+
|
89 |
+
Raises:
|
90 |
+
AssertionError: If the model family is not supported or if the transformers version is incompatible.
|
91 |
+
"""
|
92 |
+
super().__init__(
|
93 |
+
model_family=model_family,
|
94 |
+
is_instruct_model=is_instruct_model,
|
95 |
+
local_path=tokenizer_path,
|
96 |
+
)
|
97 |
+
assert model_family in ["pixtral"], f"Unsupported model family: {model_family}"
|
98 |
+
if model_family == "pixtral":
|
99 |
+
# Need transformers>=4.45.0
|
100 |
+
assert transformers.__version__ >= "4.45.0", "Pixtral requires transformers>=4.45.0"
|
101 |
+
assert is_instruct_model, "Pixtral requires is_instruct_model=True"
|
102 |
+
if not hasattr(self.tokenizer, "chat_template") or self.tokenizer.chat_template is None:
|
103 |
+
setattr(self.tokenizer, "chat_template", PIXTRAL_CHAT_TEMPLATE)
|
104 |
+
log.debug(f"Pixtral tokenizer chat template set to: {PIXTRAL_CHAT_TEMPLATE}")
|
105 |
+
|
106 |
+
# Set up image-specific configurations
|
107 |
+
image_config = IMAGE_CONFIGS[model_family]
|
108 |
+
self.patch_size = image_config["patch_size"]
|
109 |
+
self.image_token = image_config["image_token"]
|
110 |
+
self.image_break_token = image_config["image_break_token"]
|
111 |
+
self.image_end_token = image_config["image_end_token"]
|
112 |
+
|
113 |
+
# Initialize the image processor
|
114 |
+
self.image_processor = AutoImageProcessor.from_pretrained(image_processor_path)
|
115 |
+
|
116 |
+
def encode(
|
117 |
+
self,
|
118 |
+
text: Union[str, List[str], List[int]],
|
119 |
+
*, # Enforce keyword-only arguments
|
120 |
+
images: Optional[ImageInput] = None,
|
121 |
+
image_kwargs: Optional[Dict[str, Any]] = None,
|
122 |
+
**text_kwargs,
|
123 |
+
) -> List[int]:
|
124 |
+
"""
|
125 |
+
Process the images and return the tokenized images and text.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
text (`str`, `List[str]`, `List[List[str]]`):
|
129 |
+
The sequence or batch of sequences to be encoded.
|
130 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
131 |
+
The image or batch of images to be prepared.
|
132 |
+
image_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for image processing.
|
133 |
+
**text_kwargs: Additional keyword arguments for text processing.
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
A dictionary with the following fields:
|
137 |
+
- **input_ids** -- List of token ids to be fed to a model.
|
138 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
139 |
+
- **pixel_values** -- Pixel values to be fed to a model.
|
140 |
+
|
141 |
+
Raises:
|
142 |
+
ValueError: If the input images are in an invalid format.
|
143 |
+
"""
|
144 |
+
|
145 |
+
output_dict, image_inputs = {}, {}
|
146 |
+
if images is not None:
|
147 |
+
# Preprocess images
|
148 |
+
if is_image_or_image_url(images):
|
149 |
+
images = [[images]]
|
150 |
+
elif isinstance(images, list) and is_image_or_image_url(images[0]):
|
151 |
+
images = [images]
|
152 |
+
elif (
|
153 |
+
not isinstance(images, list)
|
154 |
+
and not isinstance(images[0], list)
|
155 |
+
and not is_image_or_image_url(images[0][0])
|
156 |
+
):
|
157 |
+
raise ValueError(
|
158 |
+
"Invalid input images. Please provide a single image or a list of images or a list of list of images."
|
159 |
+
)
|
160 |
+
|
161 |
+
# Load and process images
|
162 |
+
images = [load_image_list(sample) for sample in images]
|
163 |
+
image_kwargs = image_kwargs or {}
|
164 |
+
image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors="np", **image_kwargs)
|
165 |
+
|
166 |
+
# Validate image inputs
|
167 |
+
assert "pixel_values" in image_inputs, "pixel_values not found in image_inputs"
|
168 |
+
assert "image_sizes" in image_inputs, "image_sizes not found in image_inputs"
|
169 |
+
assert len(image_inputs.keys()) == 2, "Only one key is allowed in image_inputs, got {}".format(
|
170 |
+
image_inputs.keys()
|
171 |
+
)
|
172 |
+
|
173 |
+
# Extract pixel values and image sizes
|
174 |
+
pixel_values = image_inputs["pixel_values"][0]
|
175 |
+
image_sizes = image_inputs["image_sizes"][0]
|
176 |
+
unique_sizes = np.unique(image_sizes, axis=0)
|
177 |
+
|
178 |
+
assert len(unique_sizes) == 1, "All images must have the same size, got {}".format(unique_sizes)
|
179 |
+
|
180 |
+
# Convert pixel values to PyTorch tensor
|
181 |
+
pixel_values = np.asarray(pixel_values)
|
182 |
+
pixel_values = torch.from_numpy(pixel_values)
|
183 |
+
output_dict["pixel_values"] = pixel_values
|
184 |
+
output_dict["image_sizes"] = image_sizes
|
185 |
+
|
186 |
+
# Expand image tokens in text
|
187 |
+
if image_inputs.get("pixel_values") is not None:
|
188 |
+
replace_strings = []
|
189 |
+
# Calculate the number of tokens needed for each image and create a placeholder
|
190 |
+
for image_size in image_sizes:
|
191 |
+
height, width = image_size
|
192 |
+
num_height_tokens = height // self.patch_size
|
193 |
+
num_width_tokens = width // self.patch_size
|
194 |
+
replace_tokens = [[self.image_token] * num_width_tokens + [self.image_break_token]] * num_height_tokens
|
195 |
+
# Flatten list
|
196 |
+
replace_tokens = [item for sublist in replace_tokens for item in sublist]
|
197 |
+
replace_tokens[-1] = self.image_end_token
|
198 |
+
replace_str = "".join(replace_tokens)
|
199 |
+
replace_strings.append(replace_str)
|
200 |
+
text = text.replace(self.image_token, "<placeholder>", 1)
|
201 |
+
|
202 |
+
# Replace placeholders with actual image token sequences
|
203 |
+
while "<placeholder>" in text:
|
204 |
+
replace_str = replace_strings.pop(0)
|
205 |
+
text = text.replace("<placeholder>", replace_str, 1)
|
206 |
+
|
207 |
+
# Encode the text
|
208 |
+
text_inputs = super(ImageTextTokenizer, self).encode(text, **text_kwargs)
|
209 |
+
|
210 |
+
output_dict["input_ids"] = text_inputs
|
211 |
+
return output_dict
|
212 |
+
|
213 |
+
def apply_chat_template(
|
214 |
+
self,
|
215 |
+
conversation: List[Dict[str, Any]] | List[List[Dict[str, Any]]],
|
216 |
+
*,
|
217 |
+
images: Optional[ImageInput] = None,
|
218 |
+
image_kwargs: Optional[Dict[str, Any]] = None,
|
219 |
+
add_generation_prompt: bool = False,
|
220 |
+
tokenize: bool = True,
|
221 |
+
padding: bool = False,
|
222 |
+
truncation: bool = False,
|
223 |
+
max_length: Optional[int] = None,
|
224 |
+
return_tensors: Optional[str] = None,
|
225 |
+
return_dict: bool = True,
|
226 |
+
return_assistant_tokens_mask: bool = False,
|
227 |
+
generation_prefix: str = "",
|
228 |
+
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
|
229 |
+
**kwargs,
|
230 |
+
):
|
231 |
+
"""
|
232 |
+
Apply the chat template to the conversation.
|
233 |
+
|
234 |
+
Args:
|
235 |
+
conversation (List[Dict[str, Any]] | List[List[Dict[str, Any]]]): The conversation to process.
|
236 |
+
images (Optional[ImageInput]): Images to include in the conversation.
|
237 |
+
image_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for image processing.
|
238 |
+
add_generation_prompt (bool): Whether to add a generation prompt.
|
239 |
+
tokenize (bool): Whether to tokenize the output.
|
240 |
+
padding (bool): Whether to pad the output.
|
241 |
+
truncation (bool): Whether to truncate the output.
|
242 |
+
max_length (Optional[int]): Maximum length of the output.
|
243 |
+
return_tensors (Optional[str]): The type of tensors to return.
|
244 |
+
return_dict (bool): Whether to return a dictionary.
|
245 |
+
return_assistant_tokens_mask (bool): Whether to return the assistant tokens mask.
|
246 |
+
generation_prefix (str): Prefix to add before asking model to generate. Helpful to guide the generation. Defaults to "".
|
247 |
+
tokenizer_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer.
|
248 |
+
**kwargs: Additional keyword arguments.
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
The processed conversation with applied chat template.
|
252 |
+
|
253 |
+
Raises:
|
254 |
+
AssertionError: If return_dict is False or if the conversation format is invalid.
|
255 |
+
"""
|
256 |
+
assert return_dict, "return_dict must be True for ImageTextTokenizer"
|
257 |
+
assert isinstance(conversation, list), "conversation must be a list"
|
258 |
+
if isinstance(conversation[0], list):
|
259 |
+
assert len(conversation) == 1, "Only support single-conversation input, got {}".format(conversation)
|
260 |
+
conversation = conversation[0]
|
261 |
+
|
262 |
+
# Extract images from the conversation if not provided
|
263 |
+
if images is None:
|
264 |
+
images = []
|
265 |
+
for msg in conversation:
|
266 |
+
if msg.get("images", None) is not None:
|
267 |
+
images = images + (msg["images"])
|
268 |
+
images = load_image_list(images)
|
269 |
+
# In case the input does not have images, will ignore
|
270 |
+
# Useful in feeding VLM inputs with and without images
|
271 |
+
if isinstance(images, list) and len(images) == 0:
|
272 |
+
images = None
|
273 |
+
|
274 |
+
# Apply the chat template to the text
|
275 |
+
text = super().apply_chat_template(
|
276 |
+
conversation,
|
277 |
+
tokenize=False,
|
278 |
+
add_generation_prompt=add_generation_prompt,
|
279 |
+
padding=padding,
|
280 |
+
truncation=truncation,
|
281 |
+
max_length=max_length,
|
282 |
+
return_tensors=return_tensors,
|
283 |
+
return_dict=False,
|
284 |
+
return_assistant_tokens_mask=return_assistant_tokens_mask,
|
285 |
+
generation_prefix=generation_prefix,
|
286 |
+
tokenizer_kwargs=tokenizer_kwargs,
|
287 |
+
**kwargs,
|
288 |
+
)
|
289 |
+
|
290 |
+
if tokenizer_kwargs is None:
|
291 |
+
tokenizer_kwargs = {}
|
292 |
+
|
293 |
+
# Encode the text and images
|
294 |
+
output = self.encode(
|
295 |
+
text,
|
296 |
+
images=images,
|
297 |
+
image_kwargs=image_kwargs,
|
298 |
+
tokenize=tokenize,
|
299 |
+
padding=padding,
|
300 |
+
truncation=truncation,
|
301 |
+
max_length=max_length,
|
302 |
+
add_special_tokens=False,
|
303 |
+
return_tensors=return_tensors,
|
304 |
+
**tokenizer_kwargs,
|
305 |
+
)
|
306 |
+
return output
|
307 |
+
|
308 |
+
@property
|
309 |
+
def model_input_names(self):
|
310 |
+
"""
|
311 |
+
Get the combined model input names from both the text tokenizer and image processor.
|
312 |
+
|
313 |
+
Returns:
|
314 |
+
List[str]: A list of unique input names.
|
315 |
+
"""
|
316 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
317 |
+
image_processor_input_names = self.image_processor.model_input_names
|
318 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
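Editor's note: the image-token expansion arithmetic used in ImageTextTokenizer.encode above, shown standalone. For a Pixtral-style processor with patch_size=16, an H x W image becomes (H//16) rows of (W//16) [IMG] tokens, each row terminated by [IMG_BREAK], with the final token replaced by [IMG_END]. The image size below is illustrative.

patch_size = 16
height, width = 256, 512
num_height_tokens = height // patch_size            # 16 rows of patches
num_width_tokens = width // patch_size              # 32 image tokens per row
replace_tokens = [["[IMG]"] * num_width_tokens + ["[IMG_BREAK]"]] * num_height_tokens
replace_tokens = [tok for row in replace_tokens for tok in row]
replace_tokens[-1] = "[IMG_END]"
assert len(replace_tokens) == num_height_tokens * (num_width_tokens + 1)   # 16 * 33 = 528 placeholder tokens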
ar_tokenizer_modules.py
ADDED
@@ -0,0 +1,560 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""The model definition for 3D layers
|
17 |
+
|
18 |
+
Adapted from: https://github.com/lucidrains/magvit2-pytorch/blob/9f49074179c912736e617d61b32be367eb5f993a/
|
19 |
+
magvit2_pytorch/magvit2_pytorch.py#L889
|
20 |
+
|
21 |
+
[MIT License Copyright (c) 2023 Phil Wang]
|
22 |
+
https://github.com/lucidrains/magvit2-pytorch/blob/9f49074179c912736e617d61b32be367eb5f993a/LICENSE
|
23 |
+
"""
|
24 |
+
import math
|
25 |
+
from typing import Tuple, Union
|
26 |
+
|
27 |
+
import numpy as np
|
28 |
+
import torch
|
29 |
+
import torch.nn as nn
|
30 |
+
import torch.nn.functional as F
|
31 |
+
|
32 |
+
from .ar_tokenizer_patching import Patcher3D, UnPatcher3D
|
33 |
+
from .ar_tokenizer_utils import (
|
34 |
+
CausalNormalize,
|
35 |
+
batch2space,
|
36 |
+
batch2time,
|
37 |
+
cast_tuple,
|
38 |
+
is_odd,
|
39 |
+
nonlinearity,
|
40 |
+
replication_pad,
|
41 |
+
space2batch,
|
42 |
+
time2batch,
|
43 |
+
)
|
44 |
+
from .log import log
|
45 |
+
|
46 |
+
|
47 |
+
class CausalConv3d(nn.Module):
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
chan_in: int = 1,
|
51 |
+
chan_out: int = 1,
|
52 |
+
kernel_size: Union[int, Tuple[int, int, int]] = 3,
|
53 |
+
pad_mode: str = "constant",
|
54 |
+
**kwargs,
|
55 |
+
):
|
56 |
+
super().__init__()
|
57 |
+
kernel_size = cast_tuple(kernel_size, 3)
|
58 |
+
|
59 |
+
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
|
60 |
+
|
61 |
+
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
|
62 |
+
|
63 |
+
dilation = kwargs.pop("dilation", 1)
|
64 |
+
stride = kwargs.pop("stride", 1)
|
65 |
+
time_stride = kwargs.pop("time_stride", 1)
|
66 |
+
time_dilation = kwargs.pop("time_dilation", 1)
|
67 |
+
padding = kwargs.pop("padding", 1)
|
68 |
+
|
69 |
+
self.pad_mode = pad_mode
|
70 |
+
time_pad = time_dilation * (time_kernel_size - 1) + (1 - time_stride)
|
71 |
+
self.time_pad = time_pad
|
72 |
+
|
73 |
+
self.spatial_pad = (padding, padding, padding, padding)
|
74 |
+
|
75 |
+
stride = (time_stride, stride, stride)
|
76 |
+
dilation = (time_dilation, dilation, dilation)
|
77 |
+
self.conv3d = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
78 |
+
|
79 |
+
def _replication_pad(self, x: torch.Tensor) -> torch.Tensor:
|
80 |
+
x_prev = x[:, :, :1, ...].repeat(1, 1, self.time_pad, 1, 1)
|
81 |
+
x = torch.cat([x_prev, x], dim=2)
|
82 |
+
padding = self.spatial_pad + (0, 0)
|
83 |
+
return F.pad(x, padding, mode=self.pad_mode, value=0.0)
|
84 |
+
|
85 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
86 |
+
x = self._replication_pad(x)
|
87 |
+
return self.conv3d(x)
|
88 |
+
|
89 |
+
|
90 |
+
class CausalHybridUpsample3d(nn.Module):
|
91 |
+
def __init__(self, in_channels: int, spatial_up: bool = True, temporal_up: bool = True, **ignore_kwargs) -> None:
|
92 |
+
super().__init__()
|
93 |
+
self.conv1 = (
|
94 |
+
CausalConv3d(in_channels, in_channels, kernel_size=(3, 1, 1), stride=1, time_stride=1, padding=0)
|
95 |
+
if temporal_up
|
96 |
+
else nn.Identity()
|
97 |
+
)
|
98 |
+
self.conv2 = (
|
99 |
+
CausalConv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=1, time_stride=1, padding=1)
|
100 |
+
if spatial_up
|
101 |
+
else nn.Identity()
|
102 |
+
)
|
103 |
+
self.conv3 = (
|
104 |
+
CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, time_stride=1, padding=0)
|
105 |
+
if spatial_up or temporal_up
|
106 |
+
else nn.Identity()
|
107 |
+
)
|
108 |
+
self.spatial_up = spatial_up
|
109 |
+
self.temporal_up = temporal_up
|
110 |
+
|
111 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
112 |
+
if not self.spatial_up and not self.temporal_up:
|
113 |
+
return x
|
114 |
+
|
115 |
+
# hybrid upsample temporally.
|
116 |
+
if self.temporal_up:
|
117 |
+
time_factor = 1.0 + 1.0 * (x.shape[2] > 1)
|
118 |
+
if isinstance(time_factor, torch.Tensor):
|
119 |
+
time_factor = time_factor.item()
|
120 |
+
x = x.repeat_interleave(int(time_factor), dim=2)
|
121 |
+
x = x[..., int(time_factor - 1) :, :, :]
|
122 |
+
x = self.conv1(x) + x
|
123 |
+
|
124 |
+
# hybrid upsample spatially.
|
125 |
+
if self.spatial_up:
|
126 |
+
x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4)
|
127 |
+
x = self.conv2(x) + x
|
128 |
+
|
129 |
+
# final 1x1x1 conv.
|
130 |
+
x = self.conv3(x)
|
131 |
+
return x
|
132 |
+
|
133 |
+
|
134 |
+
class CausalHybridDownsample3d(nn.Module):
|
135 |
+
def __init__(
|
136 |
+
self, in_channels: int, spatial_down: bool = True, temporal_down: bool = True, **ignore_kwargs
|
137 |
+
) -> None:
|
138 |
+
super().__init__()
|
139 |
+
self.conv1 = (
|
140 |
+
CausalConv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=2, time_stride=1, padding=0)
|
141 |
+
if spatial_down
|
142 |
+
else nn.Identity()
|
143 |
+
)
|
144 |
+
self.conv2 = (
|
145 |
+
CausalConv3d(in_channels, in_channels, kernel_size=(3, 1, 1), stride=1, time_stride=2, padding=0)
|
146 |
+
if temporal_down
|
147 |
+
else nn.Identity()
|
148 |
+
)
|
149 |
+
self.conv3 = (
|
150 |
+
CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, time_stride=1, padding=0)
|
151 |
+
if spatial_down or temporal_down
|
152 |
+
else nn.Identity()
|
153 |
+
)
|
154 |
+
self.spatial_down = spatial_down
|
155 |
+
self.temporal_down = temporal_down
|
156 |
+
|
157 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
158 |
+
if not self.spatial_down and not self.temporal_down:
|
159 |
+
return x
|
160 |
+
|
161 |
+
# hybrid downsample spatially.
|
162 |
+
if self.spatial_down:
|
163 |
+
pad = (0, 1, 0, 1, 0, 0)
|
164 |
+
x = F.pad(x, pad, mode="constant", value=0)
|
165 |
+
x1 = self.conv1(x)
|
166 |
+
x2 = F.avg_pool3d(x, kernel_size=(1, 2, 2), stride=(1, 2, 2))
|
167 |
+
x = x1 + x2
|
168 |
+
|
169 |
+
# hybrid downsample temporally.
|
170 |
+
if self.temporal_down:
|
171 |
+
x = replication_pad(x)
|
172 |
+
x1 = self.conv2(x)
|
173 |
+
x2 = F.avg_pool3d(x, kernel_size=(2, 1, 1), stride=(2, 1, 1))
|
174 |
+
x = x1 + x2
|
175 |
+
|
176 |
+
# final 1x1x1 conv.
|
177 |
+
x = self.conv3(x)
|
178 |
+
return x
|
179 |
+
|
180 |
+
|
181 |
+
class CausalResnetBlockFactorized3d(nn.Module):
|
182 |
+
def __init__(self, *, in_channels: int, out_channels: int = None, dropout: float, num_groups: int) -> None:
|
183 |
+
super().__init__()
|
184 |
+
self.in_channels = in_channels
|
185 |
+
out_channels = in_channels if out_channels is None else out_channels
|
186 |
+
|
187 |
+
self.norm1 = CausalNormalize(in_channels, num_groups=1)
|
188 |
+
self.conv1 = nn.Sequential(
|
189 |
+
CausalConv3d(in_channels, out_channels, kernel_size=(1, 3, 3), stride=1, padding=1),
|
190 |
+
CausalConv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=0),
|
191 |
+
)
|
192 |
+
self.norm2 = CausalNormalize(out_channels, num_groups=num_groups)
|
193 |
+
self.dropout = torch.nn.Dropout(dropout)
|
194 |
+
self.conv2 = nn.Sequential(
|
195 |
+
CausalConv3d(out_channels, out_channels, kernel_size=(1, 3, 3), stride=1, padding=1),
|
196 |
+
CausalConv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=0),
|
197 |
+
)
|
198 |
+
self.nin_shortcut = (
|
199 |
+
CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
200 |
+
if in_channels != out_channels
|
201 |
+
else nn.Identity()
|
202 |
+
)
|
203 |
+
|
204 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
205 |
+
h = x
|
206 |
+
h = self.norm1(h)
|
207 |
+
h = nonlinearity(h)
|
208 |
+
h = self.conv1(h)
|
209 |
+
|
210 |
+
h = self.norm2(h)
|
211 |
+
h = nonlinearity(h)
|
212 |
+
h = self.dropout(h)
|
213 |
+
h = self.conv2(h)
|
214 |
+
x = self.nin_shortcut(x)
|
215 |
+
|
216 |
+
return x + h
|
217 |
+
|
218 |
+
|
219 |
+
class CausalAttnBlock(nn.Module):
|
220 |
+
def __init__(self, in_channels: int, num_groups: int) -> None:
|
221 |
+
super().__init__()
|
222 |
+
|
223 |
+
self.norm = CausalNormalize(in_channels, num_groups=num_groups)
|
224 |
+
self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
225 |
+
self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
226 |
+
self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
227 |
+
self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
228 |
+
|
229 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
230 |
+
h_ = x
|
231 |
+
h_ = self.norm(h_)
|
232 |
+
q = self.q(h_)
|
233 |
+
k = self.k(h_)
|
234 |
+
v = self.v(h_)
|
235 |
+
|
236 |
+
# compute attention
|
237 |
+
q, batch_size = time2batch(q)
|
238 |
+
k, batch_size = time2batch(k)
|
239 |
+
v, batch_size = time2batch(v)
|
240 |
+
|
241 |
+
b, c, h, w = q.shape
|
242 |
+
q = q.reshape(b, c, h * w)
|
243 |
+
q = q.permute(0, 2, 1)
|
244 |
+
k = k.reshape(b, c, h * w)
|
245 |
+
w_ = torch.bmm(q, k)
|
246 |
+
w_ = w_ * (int(c) ** (-0.5))
|
247 |
+
w_ = F.softmax(w_, dim=2)
|
248 |
+
|
249 |
+
# attend to values
|
250 |
+
v = v.reshape(b, c, h * w)
|
251 |
+
w_ = w_.permute(0, 2, 1)
|
252 |
+
h_ = torch.bmm(v, w_)
|
253 |
+
h_ = h_.reshape(b, c, h, w)
|
254 |
+
|
255 |
+
h_ = batch2time(h_, batch_size)
|
256 |
+
h_ = self.proj_out(h_)
|
257 |
+
return x + h_
|
258 |
+
|
259 |
+
|
260 |
+
class CausalTemporalAttnBlock(nn.Module):
|
261 |
+
def __init__(self, in_channels: int, num_groups: int) -> None:
|
262 |
+
super().__init__()
|
263 |
+
|
264 |
+
self.norm = CausalNormalize(in_channels, num_groups=num_groups)
|
265 |
+
self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
266 |
+
self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
267 |
+
self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
268 |
+
self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
269 |
+
|
270 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
271 |
+
h_ = x
|
272 |
+
h_ = self.norm(h_)
|
273 |
+
q = self.q(h_)
|
274 |
+
k = self.k(h_)
|
275 |
+
v = self.v(h_)
|
276 |
+
|
277 |
+
# compute attention
|
278 |
+
q, batch_size, height = space2batch(q)
|
279 |
+
k, _, _ = space2batch(k)
|
280 |
+
v, _, _ = space2batch(v)
|
281 |
+
|
282 |
+
bhw, c, t = q.shape
|
283 |
+
q = q.permute(0, 2, 1) # (bhw, t, c)
|
284 |
+
k = k.permute(0, 2, 1) # (bhw, t, c)
|
285 |
+
v = v.permute(0, 2, 1) # (bhw, t, c)
|
286 |
+
|
287 |
+
w_ = torch.bmm(q, k.permute(0, 2, 1)) # (bhw, t, t)
|
288 |
+
w_ = w_ * (int(c) ** (-0.5))
|
289 |
+
|
290 |
+
# Apply causal mask
|
291 |
+
mask = torch.tril(torch.ones_like(w_))
|
292 |
+
w_ = w_.masked_fill(mask == 0, float("-inf"))
|
293 |
+
w_ = F.softmax(w_, dim=2)
|
294 |
+
|
295 |
+
# attend to values
|
296 |
+
h_ = torch.bmm(w_, v) # (bhw, t, c)
|
297 |
+
h_ = h_.permute(0, 2, 1).reshape(bhw, c, t) # (bhw, c, t)
|
298 |
+
|
299 |
+
h_ = batch2space(h_, batch_size, height)
|
300 |
+
h_ = self.proj_out(h_)
|
301 |
+
return x + h_
|
302 |
+
|
303 |
+
|
304 |
+
class EncoderFactorized(nn.Module):
|
305 |
+
def __init__(
|
306 |
+
self,
|
307 |
+
in_channels: int,
|
308 |
+
channels: int,
|
309 |
+
channels_mult: list[int],
|
310 |
+
num_res_blocks: int,
|
311 |
+
attn_resolutions: list[int],
|
312 |
+
dropout: float,
|
313 |
+
resolution: int,
|
314 |
+
z_channels: int,
|
315 |
+
spatial_compression: int,
|
316 |
+
temporal_compression: int,
|
317 |
+
**ignore_kwargs,
|
318 |
+
) -> None:
|
319 |
+
super().__init__()
|
320 |
+
self.num_resolutions = len(channels_mult)
|
321 |
+
self.num_res_blocks = num_res_blocks
|
322 |
+
|
323 |
+
# Patcher.
|
324 |
+
patch_size = ignore_kwargs.get("patch_size", 1)
|
325 |
+
self.patcher3d = Patcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange"))
|
326 |
+
in_channels = in_channels * patch_size * patch_size * patch_size
|
327 |
+
|
328 |
+
# calculate the number of downsample operations
|
329 |
+
self.num_spatial_downs = int(math.log2(spatial_compression)) - int(math.log2(patch_size))
|
330 |
+
assert (
|
331 |
+
self.num_spatial_downs <= self.num_resolutions
|
332 |
+
), f"Spatially downsample {self.num_resolutions} times at most"
|
333 |
+
|
334 |
+
self.num_temporal_downs = int(math.log2(temporal_compression)) - int(math.log2(patch_size))
|
335 |
+
assert (
|
336 |
+
self.num_temporal_downs <= self.num_resolutions
|
337 |
+
), f"Temporally downsample {self.num_resolutions} times at most"
|
338 |
+
|
339 |
+
# downsampling
|
340 |
+
self.conv_in = nn.Sequential(
|
341 |
+
CausalConv3d(in_channels, channels, kernel_size=(1, 3, 3), stride=1, padding=1),
|
342 |
+
CausalConv3d(channels, channels, kernel_size=(3, 1, 1), stride=1, padding=0),
|
343 |
+
)
|
344 |
+
|
345 |
+
curr_res = resolution // patch_size
|
346 |
+
in_ch_mult = (1,) + tuple(channels_mult)
|
347 |
+
self.in_ch_mult = in_ch_mult
|
348 |
+
self.down = nn.ModuleList()
|
349 |
+
for i_level in range(self.num_resolutions):
|
350 |
+
block = nn.ModuleList()
|
351 |
+
attn = nn.ModuleList()
|
352 |
+
block_in = channels * in_ch_mult[i_level]
|
353 |
+
block_out = channels * channels_mult[i_level]
|
354 |
+
for _ in range(self.num_res_blocks):
|
355 |
+
block.append(
|
356 |
+
CausalResnetBlockFactorized3d(
|
357 |
+
in_channels=block_in, out_channels=block_out, dropout=dropout, num_groups=1
|
358 |
+
)
|
359 |
+
)
|
360 |
+
block_in = block_out
|
361 |
+
if curr_res in attn_resolutions:
|
362 |
+
attn.append(
|
363 |
+
nn.Sequential(
|
364 |
+
CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1)
|
365 |
+
)
|
366 |
+
)
|
367 |
+
down = nn.Module()
|
368 |
+
down.block = block
|
369 |
+
down.attn = attn
|
370 |
+
if i_level != self.num_resolutions - 1:
|
371 |
+
spatial_down = i_level < self.num_spatial_downs
|
372 |
+
temporal_down = i_level < self.num_temporal_downs
|
373 |
+
down.downsample = CausalHybridDownsample3d(
|
374 |
+
block_in, spatial_down=spatial_down, temporal_down=temporal_down
|
375 |
+
)
|
376 |
+
curr_res = curr_res // 2
|
377 |
+
self.down.append(down)
|
378 |
+
|
379 |
+
# middle
|
380 |
+
self.mid = nn.Module()
|
381 |
+
self.mid.block_1 = CausalResnetBlockFactorized3d(
|
382 |
+
in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1
|
383 |
+
)
|
384 |
+
self.mid.attn_1 = nn.Sequential(
|
385 |
+
CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1)
|
386 |
+
)
|
387 |
+
self.mid.block_2 = CausalResnetBlockFactorized3d(
|
388 |
+
in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1
|
389 |
+
)
|
390 |
+
|
391 |
+
# end
|
392 |
+
self.norm_out = CausalNormalize(block_in, num_groups=1)
|
393 |
+
self.conv_out = nn.Sequential(
|
394 |
+
CausalConv3d(block_in, z_channels, kernel_size=(1, 3, 3), stride=1, padding=1),
|
395 |
+
CausalConv3d(z_channels, z_channels, kernel_size=(3, 1, 1), stride=1, padding=0),
|
396 |
+
)
|
397 |
+
|
398 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
399 |
+
x = self.patcher3d(x)
|
400 |
+
|
401 |
+
# downsampling
|
402 |
+
h = self.conv_in(x)
|
403 |
+
for i_level in range(self.num_resolutions):
|
404 |
+
for i_block in range(self.num_res_blocks):
|
405 |
+
h = self.down[i_level].block[i_block](h)
|
406 |
+
if len(self.down[i_level].attn) > 0:
|
407 |
+
h = self.down[i_level].attn[i_block](h)
|
408 |
+
if i_level != self.num_resolutions - 1:
|
409 |
+
h = self.down[i_level].downsample(h)
|
410 |
+
|
411 |
+
# middle
|
412 |
+
h = self.mid.block_1(h)
|
413 |
+
h = self.mid.attn_1(h)
|
414 |
+
h = self.mid.block_2(h)
|
415 |
+
|
416 |
+
# end
|
417 |
+
h = self.norm_out(h)
|
418 |
+
h = nonlinearity(h)
|
419 |
+
h = self.conv_out(h)
|
420 |
+
return h
|
421 |
+
|
422 |
+
|
423 |
+
class DecoderFactorized(nn.Module):
|
424 |
+
def __init__(
|
425 |
+
self,
|
426 |
+
out_channels: int,
|
427 |
+
channels: int,
|
428 |
+
channels_mult: list[int],
|
429 |
+
num_res_blocks: int,
|
430 |
+
attn_resolutions: list[int],
|
431 |
+
dropout: float,
|
432 |
+
resolution: int,
|
433 |
+
z_channels: int,
|
434 |
+
spatial_compression: int,
|
435 |
+
temporal_compression: int,
|
436 |
+
**ignore_kwargs,
|
437 |
+
):
|
438 |
+
super().__init__()
|
439 |
+
self.num_resolutions = len(channels_mult)
|
440 |
+
self.num_res_blocks = num_res_blocks
|
441 |
+
|
442 |
+
# UnPatcher.
|
443 |
+
patch_size = ignore_kwargs.get("patch_size", 1)
|
444 |
+
self.unpatcher3d = UnPatcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange"))
|
445 |
+
out_ch = out_channels * patch_size * patch_size * patch_size
|
446 |
+
|
447 |
+
# calculate the number of upsample operations
|
448 |
+
self.num_spatial_ups = int(math.log2(spatial_compression)) - int(math.log2(patch_size))
|
449 |
+
assert self.num_spatial_ups <= self.num_resolutions, f"Spatially upsample {self.num_resolutions} times at most"
|
450 |
+
self.num_temporal_ups = int(math.log2(temporal_compression)) - int(math.log2(patch_size))
|
451 |
+
assert (
|
452 |
+
self.num_temporal_ups <= self.num_resolutions
|
453 |
+
), f"Temporally upsample {self.num_resolutions} times at most"
|
454 |
+
|
455 |
+
block_in = channels * channels_mult[self.num_resolutions - 1]
|
456 |
+
curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1)
|
457 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
458 |
+
log.debug("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
459 |
+
|
460 |
+
# z to block_in
|
461 |
+
self.conv_in = nn.Sequential(
|
462 |
+
CausalConv3d(z_channels, block_in, kernel_size=(1, 3, 3), stride=1, padding=1),
|
463 |
+
CausalConv3d(block_in, block_in, kernel_size=(3, 1, 1), stride=1, padding=0),
|
464 |
+
)
|
465 |
+
|
466 |
+
# middle
|
467 |
+
self.mid = nn.Module()
|
468 |
+
self.mid.block_1 = CausalResnetBlockFactorized3d(
|
469 |
+
in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1
|
470 |
+
)
|
471 |
+
self.mid.attn_1 = nn.Sequential(
|
472 |
+
CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1)
|
473 |
+
)
|
474 |
+
self.mid.block_2 = CausalResnetBlockFactorized3d(
|
475 |
+
in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1
|
476 |
+
)
|
477 |
+
|
478 |
+
legacy_mode = ignore_kwargs.get("legacy_mode", False)
|
479 |
+
# upsampling
|
480 |
+
self.up = nn.ModuleList()
|
481 |
+
for i_level in reversed(range(self.num_resolutions)):
|
482 |
+
block = nn.ModuleList()
|
483 |
+
attn = nn.ModuleList()
|
484 |
+
block_out = channels * channels_mult[i_level]
|
485 |
+
for _ in range(self.num_res_blocks + 1):
|
486 |
+
block.append(
|
487 |
+
CausalResnetBlockFactorized3d(
|
488 |
+
in_channels=block_in, out_channels=block_out, dropout=dropout, num_groups=1
|
489 |
+
)
|
490 |
+
)
|
491 |
+
block_in = block_out
|
492 |
+
if curr_res in attn_resolutions:
|
493 |
+
attn.append(
|
494 |
+
nn.Sequential(
|
495 |
+
CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1)
|
496 |
+
)
|
497 |
+
)
|
498 |
+
up = nn.Module()
|
499 |
+
up.block = block
|
500 |
+
up.attn = attn
|
501 |
+
if i_level != 0:
|
502 |
+
# The layer index for temporal/spatial downsampling performed in the encoder should correspond
|
503 |
+
# to the layer index, in reverse order, where upsampling is performed in the decoder.
|
504 |
+
# If you have a pre-trained model, you can simply fine-tune it.
|
505 |
+
# For example:
|
506 |
+
# Input tensor = (1, 3, 17, 32, 32)
|
507 |
+
# Patch size = 4 for 3D wavelet transform
|
508 |
+
# Compression rate = (8x16x16)
|
509 |
+
#
|
510 |
+
# We expect successive downsampling in the encoder and upsampling in the decoder to be mirrored.
|
511 |
+
# ENCODER: `(...,5,8,8) -> (...,3,4,4) -> (...,3,2,2)`
|
512 |
+
# DECODER: `(...,3,2,2) -> (...,3,4,4) -> (...,5,8,8)`
|
513 |
+
#
|
514 |
+
# if legacy_mode is True, the temporal upsampling is not perfectly mirrored.
|
515 |
+
# ENCODER: `(...,5,8,8) -> (...,3,4,4) -> (...,3,2,2)`
|
516 |
+
# DECODER: `(...,3,2,2) -> (...,5,4,4) -> (...,5,8,8)`
|
517 |
+
#
|
518 |
+
# Most of the CV and DV tokenizers were trained before 09/01/2024 with upsampling that's not mirrored.
|
519 |
+
# Going forward, new CV/DV tokenizers will adopt `legacy_mode=False`, i.e. use mirrored upsampling.
|
520 |
+
i_level_reverse = self.num_resolutions - i_level - 1
|
521 |
+
if legacy_mode:
|
522 |
+
temporal_up = i_level_reverse < self.num_temporal_ups
|
523 |
+
else:
|
524 |
+
temporal_up = 0 < i_level_reverse < self.num_temporal_ups + 1
|
525 |
+
spatial_up = temporal_up or (
|
526 |
+
i_level_reverse < self.num_spatial_ups and self.num_spatial_ups > self.num_temporal_ups
|
527 |
+
)
|
528 |
+
up.upsample = CausalHybridUpsample3d(block_in, spatial_up=spatial_up, temporal_up=temporal_up)
|
529 |
+
curr_res = curr_res * 2
|
530 |
+
self.up.insert(0, up) # prepend to get consistent order
|
531 |
+
|
532 |
+
# end
|
533 |
+
self.norm_out = CausalNormalize(block_in, num_groups=1)
|
534 |
+
self.conv_out = nn.Sequential(
|
535 |
+
CausalConv3d(block_in, out_ch, kernel_size=(1, 3, 3), stride=1, padding=1),
|
536 |
+
CausalConv3d(out_ch, out_ch, kernel_size=(3, 1, 1), stride=1, padding=0),
|
537 |
+
)
|
538 |
+
|
539 |
+
def forward(self, z):
|
540 |
+
h = self.conv_in(z)
|
541 |
+
|
542 |
+
# middle block.
|
543 |
+
h = self.mid.block_1(h)
|
544 |
+
h = self.mid.attn_1(h)
|
545 |
+
h = self.mid.block_2(h)
|
546 |
+
|
547 |
+
# decoder blocks.
|
548 |
+
for i_level in reversed(range(self.num_resolutions)):
|
549 |
+
for i_block in range(self.num_res_blocks + 1):
|
550 |
+
h = self.up[i_level].block[i_block](h)
|
551 |
+
if len(self.up[i_level].attn) > 0:
|
552 |
+
h = self.up[i_level].attn[i_block](h)
|
553 |
+
if i_level != 0:
|
554 |
+
h = self.up[i_level].upsample(h)
|
555 |
+
|
556 |
+
h = self.norm_out(h)
|
557 |
+
h = nonlinearity(h)
|
558 |
+
h = self.conv_out(h)
|
559 |
+
h = self.unpatcher3d(h)
|
560 |
+
return h
|
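A minimal usage sketch follows (it is not part of the file above): it reproduces, with plain torch only, the causal temporal padding that CausalConv3d applies before its 3D convolution. The tensor shape, kernel size, and padding amounts are illustrative assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Causal padding: replicate the first frame time_pad times so the output at time t
# depends only on input frames <= t (no peeking into the future).
time_kernel, time_pad = 3, 2                        # time_pad = time_dilation * (k_t - 1) for stride 1
x = torch.randn(1, 1, 5, 8, 8)                      # (batch, channels, time, height, width)
x_prev = x[:, :, :1].repeat(1, 1, time_pad, 1, 1)   # copies of the first frame
x_padded = torch.cat([x_prev, x], dim=2)            # left-only (causal) temporal padding
x_padded = F.pad(x_padded, (1, 1, 1, 1))            # symmetric spatial padding for the 3x3 kernel
conv = nn.Conv3d(1, 1, kernel_size=(time_kernel, 3, 3), stride=1)
y = conv(x_padded)
print(y.shape)                                      # torch.Size([1, 1, 5, 8, 8]): time length is preserved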
ar_tokenizer_patching.py
ADDED
@@ -0,0 +1,279 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""The patcher and unpatcher implementation for 2D and 3D data."""
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn.functional as F
|
20 |
+
from einops import rearrange
|
21 |
+
|
22 |
+
_WAVELETS = {
|
23 |
+
"haar": torch.tensor([0.7071067811865476, 0.7071067811865476]),
|
24 |
+
"rearrange": torch.tensor([1.0, 1.0]),
|
25 |
+
}
|
26 |
+
_PERSISTENT = False
|
27 |
+
|
28 |
+
|
29 |
+
class Patcher(torch.nn.Module):
|
30 |
+
"""A module to convert image tensors into patches using torch operations.
|
31 |
+
|
32 |
+
The main difference from `class Patching` is that this module implements
|
33 |
+
all operations using torch, rather than python or numpy, for efficiency purposes.
|
34 |
+
|
35 |
+
It's bit-wise identical to the Patching module outputs, with the added
|
36 |
+
benefit of being torch.jit scriptable.
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, patch_size=1, patch_method="haar"):
|
40 |
+
super().__init__()
|
41 |
+
self.patch_size = patch_size
|
42 |
+
self.patch_method = patch_method
|
43 |
+
self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT)
|
44 |
+
self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
|
45 |
+
self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=_PERSISTENT)
|
46 |
+
for param in self.parameters():
|
47 |
+
param.requires_grad = False
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
if self.patch_method == "haar":
|
51 |
+
return self._haar(x)
|
52 |
+
elif self.patch_method == "rearrange":
|
53 |
+
return self._arrange(x)
|
54 |
+
else:
|
55 |
+
raise ValueError("Unknown patch method: " + self.patch_method)
|
56 |
+
|
57 |
+
def _dwt(self, x, mode="reflect", rescale=False):
|
58 |
+
dtype = x.dtype
|
59 |
+
h = self.wavelets
|
60 |
+
|
61 |
+
n = h.shape[0]
|
62 |
+
g = x.shape[1]
|
63 |
+
hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
|
64 |
+
hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
|
65 |
+
hh = hh.to(dtype=dtype)
|
66 |
+
hl = hl.to(dtype=dtype)
|
67 |
+
|
68 |
+
x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)
|
69 |
+
xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2))
|
70 |
+
xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2))
|
71 |
+
xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1))
|
72 |
+
xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1))
|
73 |
+
xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1))
|
74 |
+
xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1))
|
75 |
+
|
76 |
+
out = torch.cat([xll, xlh, xhl, xhh], dim=1)
|
77 |
+
if rescale:
|
78 |
+
out = out / 2
|
79 |
+
return out
|
80 |
+
|
81 |
+
def _haar(self, x):
|
82 |
+
for _ in self.range:
|
83 |
+
x = self._dwt(x, rescale=True)
|
84 |
+
return x
|
85 |
+
|
86 |
+
def _arrange(self, x):
|
87 |
+
x = rearrange(x, "b c (h p1) (w p2) -> b (c p1 p2) h w", p1=self.patch_size, p2=self.patch_size).contiguous()
|
88 |
+
return x
|
89 |
+
|
90 |
+
|
91 |
+
class Patcher3D(Patcher):
|
92 |
+
"""A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos."""
|
93 |
+
|
94 |
+
def __init__(self, patch_size=1, patch_method="haar"):
|
95 |
+
super().__init__(patch_method=patch_method, patch_size=patch_size)
|
96 |
+
self.register_buffer(
|
97 |
+
"patch_size_buffer", patch_size * torch.ones([1], dtype=torch.int32), persistent=_PERSISTENT
|
98 |
+
)
|
99 |
+
|
100 |
+
def _dwt(self, x, mode="reflect", rescale=False):
|
101 |
+
dtype = x.dtype
|
102 |
+
h = self.wavelets
|
103 |
+
|
104 |
+
n = h.shape[0]
|
105 |
+
g = x.shape[1]
|
106 |
+
hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
|
107 |
+
hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
|
108 |
+
hh = hh.to(dtype=dtype)
|
109 |
+
hl = hl.to(dtype=dtype)
|
110 |
+
|
111 |
+
# Handles temporal axis.
|
112 |
+
x = F.pad(x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)
|
113 |
+
xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
|
114 |
+
xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
|
115 |
+
|
116 |
+
# Handles spatial axes.
|
117 |
+
xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
118 |
+
xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
119 |
+
xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
120 |
+
xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
121 |
+
|
122 |
+
xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
123 |
+
xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
124 |
+
xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
125 |
+
xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
126 |
+
xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
127 |
+
xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
128 |
+
xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
129 |
+
xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
130 |
+
|
131 |
+
out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1)
|
132 |
+
if rescale:
|
133 |
+
out = out / (2 * torch.sqrt(torch.tensor(2.0)))
|
134 |
+
return out
|
135 |
+
|
136 |
+
def _haar(self, x):
|
137 |
+
xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
|
138 |
+
x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
|
139 |
+
for _ in self.range:
|
140 |
+
x = self._dwt(x, rescale=True)
|
141 |
+
return x
|
142 |
+
|
143 |
+
def _arrange(self, x):
|
144 |
+
xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
|
145 |
+
x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
|
146 |
+
x = rearrange(
|
147 |
+
x,
|
148 |
+
"b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w",
|
149 |
+
p1=self.patch_size,
|
150 |
+
p2=self.patch_size,
|
151 |
+
p3=self.patch_size,
|
152 |
+
).contiguous()
|
153 |
+
return x
|
154 |
+
|
155 |
+
|
156 |
+
class UnPatcher(torch.nn.Module):
|
157 |
+
"""A module to convert patches into image tensorsusing torch operations.
|
158 |
+
|
159 |
+
The main difference from `class Unpatching` is that this module implements
|
160 |
+
all operations using torch, rather than python or numpy, for efficiency purposes.
|
161 |
+
|
162 |
+
It's bit-wise identical to the Unpatching module outputs, with the added
|
163 |
+
benefit of being torch.jit scriptable.
|
164 |
+
"""
|
165 |
+
|
166 |
+
def __init__(self, patch_size=1, patch_method="haar"):
|
167 |
+
super().__init__()
|
168 |
+
self.patch_size = patch_size
|
169 |
+
self.patch_method = patch_method
|
170 |
+
self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT)
|
171 |
+
self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
|
172 |
+
self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=_PERSISTENT)
|
173 |
+
for param in self.parameters():
|
174 |
+
param.requires_grad = False
|
175 |
+
|
176 |
+
def forward(self, x):
|
177 |
+
if self.patch_method == "haar":
|
178 |
+
return self._ihaar(x)
|
179 |
+
elif self.patch_method == "rearrange":
|
180 |
+
return self._iarrange(x)
|
181 |
+
else:
|
182 |
+
raise ValueError("Unknown patch method: " + self.patch_method)
|
183 |
+
|
184 |
+
def _idwt(self, x, rescale=False):
|
185 |
+
dtype = x.dtype
|
186 |
+
h = self.wavelets
|
187 |
+
n = h.shape[0]
|
188 |
+
|
189 |
+
g = x.shape[1] // 4
|
190 |
+
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
|
191 |
+
hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
|
192 |
+
hh = hh.to(dtype=dtype)
|
193 |
+
hl = hl.to(dtype=dtype)
|
194 |
+
|
195 |
+
xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)
|
196 |
+
|
197 |
+
# Inverse transform.
|
198 |
+
yl = torch.nn.functional.conv_transpose2d(xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0))
|
199 |
+
yl += torch.nn.functional.conv_transpose2d(xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0))
|
200 |
+
yh = torch.nn.functional.conv_transpose2d(xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0))
|
201 |
+
yh += torch.nn.functional.conv_transpose2d(xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0))
|
202 |
+
y = torch.nn.functional.conv_transpose2d(yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2))
|
203 |
+
y += torch.nn.functional.conv_transpose2d(yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2))
|
204 |
+
|
205 |
+
if rescale:
|
206 |
+
y = y * 2
|
207 |
+
return y
|
208 |
+
|
209 |
+
def _ihaar(self, x):
|
210 |
+
for _ in self.range:
|
211 |
+
x = self._idwt(x, rescale=True)
|
212 |
+
return x
|
213 |
+
|
214 |
+
def _iarrange(self, x):
|
215 |
+
x = rearrange(x, "b (c p1 p2) h w -> b c (h p1) (w p2)", p1=self.patch_size, p2=self.patch_size)
|
216 |
+
return x
|
217 |
+
|
218 |
+
|
219 |
+
class UnPatcher3D(UnPatcher):
|
220 |
+
"""A 3D inverse discrete wavelet transform for video wavelet decompositions."""
|
221 |
+
|
222 |
+
def __init__(self, patch_size=1, patch_method="haar"):
|
223 |
+
super().__init__(patch_method=patch_method, patch_size=patch_size)
|
224 |
+
|
225 |
+
def _idwt(self, x, rescale=False):
|
226 |
+
dtype = x.dtype
|
227 |
+
h = self.wavelets
|
228 |
+
|
229 |
+
g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tensors.
|
230 |
+
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
|
231 |
+
hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
|
232 |
+
hl = hl.to(dtype=dtype)
|
233 |
+
hh = hh.to(dtype=dtype)
|
234 |
+
|
235 |
+
xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
|
236 |
+
|
237 |
+
# Handles height transposed convolutions.
|
238 |
+
xll = F.conv_transpose3d(xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
239 |
+
xll += F.conv_transpose3d(xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
240 |
+
|
241 |
+
xlh = F.conv_transpose3d(xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
242 |
+
xlh += F.conv_transpose3d(xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
243 |
+
|
244 |
+
xhl = F.conv_transpose3d(xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
245 |
+
xhl += F.conv_transpose3d(xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
246 |
+
|
247 |
+
xhh = F.conv_transpose3d(xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
248 |
+
xhh += F.conv_transpose3d(xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
|
249 |
+
|
250 |
+
# Handles width transposed convolutions.
|
251 |
+
xl = F.conv_transpose3d(xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
252 |
+
xl += F.conv_transpose3d(xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
253 |
+
xh = F.conv_transpose3d(xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
254 |
+
xh += F.conv_transpose3d(xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
|
255 |
+
|
256 |
+
# Handles time axis transposed convolutions.
|
257 |
+
x = F.conv_transpose3d(xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
|
258 |
+
x += F.conv_transpose3d(xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
|
259 |
+
|
260 |
+
if rescale:
|
261 |
+
x = x * (2 * torch.sqrt(torch.tensor(2.0)))
|
262 |
+
return x
|
263 |
+
|
264 |
+
def _ihaar(self, x):
|
265 |
+
for _ in self.range:
|
266 |
+
x = self._idwt(x, rescale=True)
|
267 |
+
x = x[:, :, self.patch_size - 1 :, ...]
|
268 |
+
return x
|
269 |
+
|
270 |
+
def _iarrange(self, x):
|
271 |
+
x = rearrange(
|
272 |
+
x,
|
273 |
+
"b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)",
|
274 |
+
p1=self.patch_size,
|
275 |
+
p2=self.patch_size,
|
276 |
+
p3=self.patch_size,
|
277 |
+
)
|
278 |
+
x = x[:, :, self.patch_size - 1 :, ...]
|
279 |
+
return x
|
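A rough round-trip sketch for the patcher and unpatcher above. It assumes ar_tokenizer_patching.py is importable as a top-level module and that a video with 1 + 4k frames is a valid input for patch_size=4, so treat it as an illustration rather than a tested snippet.

import torch
from ar_tokenizer_patching import Patcher3D, UnPatcher3D  # import path is an assumption

patcher = Patcher3D(patch_size=4, patch_method="haar")
unpatcher = UnPatcher3D(patch_size=4, patch_method="haar")

video = torch.randn(1, 3, 17, 32, 32)    # (batch, channels, frames, height, width)
patched = patcher(video)                 # channels grow by patch_size**3; time/height/width shrink by patch_size
print(patched.shape)
restored = unpatcher(patched)
print(restored.shape)                    # expected to match the input shape for the haar transform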
ar_tokenizer_quantizers.py
ADDED
@@ -0,0 +1,165 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Quantizers for discrete image and video tokenization."""
|
17 |
+
|
18 |
+
from typing import Optional
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
from einops import rearrange
|
23 |
+
|
24 |
+
from .ar_tokenizer_utils import default, pack_one, round_ste, unpack_one
|
25 |
+
|
26 |
+
|
27 |
+
class FSQuantizer(nn.Module):
|
28 |
+
"""Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
|
29 |
+
|
30 |
+
Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/
|
31 |
+
vector_quantize_pytorch/finite_scalar_quantization.py
|
32 |
+
[Copyright (c) 2020 Phil Wang]
|
33 |
+
https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
levels: list[int],
|
39 |
+
dim: Optional[int] = None,
|
40 |
+
num_codebooks=1,
|
41 |
+
keep_num_codebooks_dim: Optional[bool] = None,
|
42 |
+
scale: Optional[float] = None,
|
43 |
+
**ignore_kwargs,
|
44 |
+
):
|
45 |
+
super().__init__()
|
46 |
+
self.dtype = ignore_kwargs.get("dtype", torch.float32)
|
47 |
+
_levels = torch.tensor(levels, dtype=torch.int32)
|
48 |
+
self.register_buffer("_levels", _levels, persistent=False)
|
49 |
+
|
50 |
+
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32)
|
51 |
+
self.register_buffer("_basis", _basis, persistent=False)
|
52 |
+
|
53 |
+
self.scale = scale
|
54 |
+
|
55 |
+
codebook_dim = len(levels)
|
56 |
+
self.codebook_dim = codebook_dim
|
57 |
+
|
58 |
+
effective_codebook_dim = codebook_dim * num_codebooks
|
59 |
+
self.num_codebooks = num_codebooks
|
60 |
+
self.effective_codebook_dim = effective_codebook_dim
|
61 |
+
|
62 |
+
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
|
63 |
+
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
|
64 |
+
self.keep_num_codebooks_dim = keep_num_codebooks_dim
|
65 |
+
|
66 |
+
self.dim = default(dim, len(_levels) * num_codebooks)
|
67 |
+
|
68 |
+
has_projections = self.dim != effective_codebook_dim
|
69 |
+
self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
|
70 |
+
self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
|
71 |
+
self.has_projections = has_projections
|
72 |
+
|
73 |
+
self.codebook_size = self._levels.prod().item()
|
74 |
+
|
75 |
+
implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
|
76 |
+
self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)
|
77 |
+
|
78 |
+
def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
|
79 |
+
"""Bound `z`, an array of shape (..., d)."""
|
80 |
+
half_l = (self._levels - 1) * (1 + eps) / 2
|
81 |
+
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
|
82 |
+
shift = (offset / half_l).atanh()
|
83 |
+
return (z + shift).tanh() * half_l - offset
|
84 |
+
|
85 |
+
def quantize(self, z: torch.Tensor) -> torch.Tensor:
|
86 |
+
"""Quantizes z, returns quantized zhat, same shape as z."""
|
87 |
+
quantized = round_ste(self.bound(z))
|
88 |
+
half_width = self._levels // 2 # Renormalize to [-1, 1].
|
89 |
+
return quantized / half_width
|
90 |
+
|
91 |
+
def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
|
92 |
+
half_width = self._levels // 2
|
93 |
+
return (zhat_normalized * half_width) + half_width
|
94 |
+
|
95 |
+
def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
|
96 |
+
half_width = self._levels // 2
|
97 |
+
return (zhat - half_width) / half_width
|
98 |
+
|
99 |
+
def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
|
100 |
+
"""Converts a `code` to an index in the codebook."""
|
101 |
+
assert zhat.shape[-1] == self.codebook_dim
|
102 |
+
zhat = self._scale_and_shift(zhat).float()
|
103 |
+
return (zhat * self._basis).sum(dim=-1).to(torch.int32)
|
104 |
+
|
105 |
+
def indices_to_codes(self, indices: torch.Tensor, project_out=True) -> torch.Tensor:
|
106 |
+
"""Inverse of `codes_to_indices`."""
|
107 |
+
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
|
108 |
+
indices = rearrange(indices, "... -> ... 1")
|
109 |
+
codes_non_centered = (indices // self._basis) % self._levels
|
110 |
+
codes = self._scale_and_shift_inverse(codes_non_centered)
|
111 |
+
|
112 |
+
if self.keep_num_codebooks_dim:
|
113 |
+
codes = rearrange(codes, "... c d -> ... (c d)")
|
114 |
+
|
115 |
+
if project_out:
|
116 |
+
codes = self.project_out(codes)
|
117 |
+
|
118 |
+
if is_img_or_video:
|
119 |
+
codes = rearrange(codes, "b ... d -> b d ...")
|
120 |
+
|
121 |
+
return codes.to(self.dtype)
|
122 |
+
|
123 |
+
def forward(self, z: torch.Tensor) -> torch.Tensor:
|
124 |
+
"""
|
125 |
+
einstein notation
|
126 |
+
b - batch
|
127 |
+
n - sequence (or flattened spatial dimensions)
|
128 |
+
d - feature dimension, which is also log2(codebook size)
|
129 |
+
c - number of codebook dim
|
130 |
+
"""
|
131 |
+
is_img_or_video = z.ndim >= 4
|
132 |
+
|
133 |
+
# standardize image or video into (batch, seq, dimension)
|
134 |
+
|
135 |
+
if is_img_or_video:
|
136 |
+
z = rearrange(z, "b d ... -> b ... d")
|
137 |
+
z, ps = pack_one(z, "b * d")
|
138 |
+
|
139 |
+
assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
|
140 |
+
|
141 |
+
z = self.project_in(z)
|
142 |
+
|
143 |
+
z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
|
144 |
+
|
145 |
+
codes = self.quantize(z)
|
146 |
+
indices = self.codes_to_indices(codes)
|
147 |
+
|
148 |
+
codes = rearrange(codes, "b n c d -> b n (c d)")
|
149 |
+
|
150 |
+
out = self.project_out(codes)
|
151 |
+
|
152 |
+
# reconstitute image or video dimensions
|
153 |
+
|
154 |
+
if is_img_or_video:
|
155 |
+
out = unpack_one(out, ps, "b * d")
|
156 |
+
out = rearrange(out, "b ... d -> b d ...")
|
157 |
+
indices = unpack_one(indices, ps, "b * c")
|
158 |
+
dummy_loss = torch.zeros_like(out.mean(dim=[1, 2, 3], keepdim=True))
|
159 |
+
else:
|
160 |
+
dummy_loss = torch.zeros_like(out.mean(dim=[1, 2], keepdim=True)).unsqueeze(1)
|
161 |
+
|
162 |
+
if not self.keep_num_codebooks_dim:
|
163 |
+
indices = rearrange(indices, "... 1 -> ...")
|
164 |
+
|
165 |
+
return (indices, out.to(self.dtype), dummy_loss)
|
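For reference, here is a self-contained sketch (separate from the class above) of the core finite-scalar-quantization step that FSQuantizer wraps: bound each latent channel to a fixed number of levels, round with a straight-through estimator, and renormalize to [-1, 1]. The level choices and tensor shapes are illustrative assumptions.

import torch

levels = torch.tensor([8, 8, 6, 5])                 # implicit codebook size = 8 * 8 * 6 * 5
z = torch.randn(2, 7, len(levels))                  # (batch, sequence, codebook_dim)

# bound(): squash each channel into its admissible range.
half_l = (levels - 1) * (1 + 1e-3) / 2
offset = torch.where(levels % 2 == 0, 0.5, 0.0)
shift = (offset / half_l).atanh()
bounded = (z + shift).tanh() * half_l - offset

# round_ste(): round in the forward pass, pass gradients straight through.
rounded = bounded.round()
quantized = bounded + (rounded - bounded).detach()

codes = quantized / (levels // 2)                   # renormalize each channel to roughly [-1, 1]
print(codes.shape, codes.min().item(), codes.max().item())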
ar_tokenizer_text_tokenizer.py
ADDED
@@ -0,0 +1,317 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Any, Dict, List, Optional, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
from transformers import AutoTokenizer
|
21 |
+
|
22 |
+
from .log import log
|
23 |
+
|
24 |
+
|
25 |
+
def get_tokenizer_path(model_family: str, is_instruct_model: bool = False):
|
26 |
+
"""
|
27 |
+
Get the tokenizer path from the model family and instruct model flag.
|
28 |
+
Args:
|
29 |
+
model_family (str): The model family.
|
30 |
+
is_instruct_model (bool): Whether the model is an instruct model.
|
31 |
+
Returns:
|
32 |
+
str: The Hugging Face tokenizer path.
|
33 |
+
"""
|
34 |
+
model_family = model_family.lower()
|
35 |
+
if model_family == "mistral":
|
36 |
+
return "mistralai/Mistral-Nemo-Instruct-2407"
|
37 |
+
else:
|
38 |
+
assert model_family in ["llama3", "llama3.1"]
|
39 |
+
if model_family == "llama3":
|
40 |
+
model_path = "meta-llama/Meta-Llama-3-8B"
|
41 |
+
elif model_family == "llama3.1":
|
42 |
+
model_path = "meta-llama/Llama-3.1-8B"
|
43 |
+
else:
|
44 |
+
raise ValueError(f"Unsupported model family: {model_family}")
|
45 |
+
suffix = "-Instruct" if is_instruct_model else ""
|
46 |
+
model_path = f"{model_path}{suffix}"
|
47 |
+
return model_path
|
48 |
+
|
49 |
+
|
50 |
+
class TextTokenizer:
|
51 |
+
"""
|
52 |
+
Text tokenizer class built on HuggingFace's Fast Tokenizer (Rust based).
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
model_family: str,
|
58 |
+
is_instruct_model: bool,
|
59 |
+
local_path: Optional[str] = None,
|
60 |
+
):
|
61 |
+
"""
|
62 |
+
Initialize the TextTokenizer.
|
63 |
+
Args:
|
64 |
+
model_family (str): The model family.
|
65 |
+
is_instruct_model (bool): Whether the model is an instruct model.
|
66 |
+
local_path (Optional[str]): The local path to the tokenizer. If not provided, the tokenizer will be downloaded from the remote path.
|
67 |
+
"""
|
68 |
+
if local_path is None:
|
69 |
+
tokenizer_path = get_tokenizer_path(model_family, is_instruct_model)
|
70 |
+
else:
|
71 |
+
tokenizer_path = local_path
|
72 |
+
|
73 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
|
74 |
+
self.stop_tokens = {
|
75 |
+
self.tokenizer.eos_token_id,
|
76 |
+
}
|
77 |
+
self.model_family = model_family
|
78 |
+
self.is_instruct_model = is_instruct_model
|
79 |
+
self.eos_id = self.tokenizer.eos_token_id
|
80 |
+
if self.tokenizer.pad_token is None:
|
81 |
+
if model_family.startswith("llama"):
|
82 |
+
self.pad_id = 128004 # "<|finetune_right_pad_id|>"
|
83 |
+
elif model_family == "mistral":
|
84 |
+
self.pad_id = 10 # "<pad>"
|
85 |
+
elif model_family == "pixtral":
|
86 |
+
self.pad_id = 11 # "<pad>"
|
87 |
+
else:
|
88 |
+
raise ValueError(f"pad_id not defined for model_family {model_family}")
|
89 |
+
else:
|
90 |
+
self.pad_id = self.tokenizer.pad_token_id
|
91 |
+
|
92 |
+
def tokenize(self, text: str, *, add_special_tokens: bool = False, **kwargs) -> List[str]:
|
93 |
+
"""
|
94 |
+
Converts a string into a sequence of tokens, replacing unknown tokens with the `unk_token`.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
text (`str`):
|
98 |
+
The sequence to be encoded.
|
99 |
+
add_special_tokens (`bool`, *optional*, defaults to `False`):
|
100 |
+
Whether or not to add the special tokens associated with the corresponding model.
|
101 |
+
Returns:
|
102 |
+
`List[str]`: The list of tokens.
|
103 |
+
"""
|
104 |
+
return self.tokenizer.tokenize(text, add_special_tokens=add_special_tokens, **kwargs)
|
105 |
+
|
106 |
+
def encode(
|
107 |
+
self,
|
108 |
+
text: Union[str, List[str], List[int]],
|
109 |
+
*, # Enforce keyword-only arguments
|
110 |
+
add_special_tokens: bool = True,
|
111 |
+
padding: Union[bool, str] = False,
|
112 |
+
truncation: Union[bool, str] = None,
|
113 |
+
max_length: Optional[int] = None,
|
114 |
+
stride: int = 0,
|
115 |
+
return_tensors: Optional[str] = None,
|
116 |
+
**kwargs,
|
117 |
+
) -> List[int]:
|
118 |
+
"""
|
119 |
+
Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
text (`str`, `List[str]` or `List[int]`):
|
123 |
+
The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the
|
124 |
+
`tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
|
125 |
+
method).
|
126 |
+
add_special_tokens (`bool`, *optional*, defaults to `True`):
|
127 |
+
Whether or not to add special tokens when encoding the sequences. This will use the underlying
|
128 |
+
`PretrainedTokenizerBase.build_inputs_with_special_tokens` function, which defines which tokens are
|
129 |
+
automatically added to the input ids. This is useful if you want to add `bos` or `eos` tokens
|
130 |
+
automatically.
|
131 |
+
padding (`bool`, `str`, *optional*, defaults to `False`):
|
132 |
+
Activates and controls padding. Accepts the following values:
|
133 |
+
|
134 |
+
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
135 |
+
sequence if provided).
|
136 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
137 |
+
acceptable input length for the model if that argument is not provided.
|
138 |
+
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
139 |
+
lengths).
|
140 |
+
truncation (`bool`, `str`, *optional*, defaults to `False`):
|
141 |
+
Activates and controls truncation. Accepts the following values:
|
142 |
+
|
143 |
+
- `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
|
144 |
+
to the maximum acceptable input length for the model if that argument is not provided. This will
|
145 |
+
truncate token by token, removing a token from the longest sequence in the pair if a pair of
|
146 |
+
sequences (or a batch of pairs) is provided.
|
147 |
+
- `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
|
148 |
+
maximum acceptable input length for the model if that argument is not provided. This will only
|
149 |
+
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
150 |
+
- `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
|
151 |
+
maximum acceptable input length for the model if that argument is not provided. This will only
|
152 |
+
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
153 |
+
- `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
|
154 |
+
greater than the model maximum admissible input size).
|
155 |
+
max_length (`int`, *optional*):
|
156 |
+
Controls the maximum length to use by one of the truncation/padding parameters.
|
157 |
+
|
158 |
+
If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
|
159 |
+
is required by one of the truncation/padding parameters. If the model has no specific maximum input
|
160 |
+
length (like XLNet) truncation/padding to a maximum length will be deactivated.
|
161 |
+
stride (`int`, *optional*, defaults to 0):
|
162 |
+
If set to a number along with `max_length`, the overflowing tokens returned when
|
163 |
+
`return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
|
164 |
+
returned to provide some overlap between truncated and overflowing sequences. The value of this
|
165 |
+
argument defines the number of overlapping tokens.
|
166 |
+
is_split_into_words (`bool`, *optional*, defaults to `False`):
|
167 |
+
Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
|
168 |
+
tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
|
169 |
+
which it will tokenize. This is useful for NER or token classification.
|
170 |
+
pad_to_multiple_of (`int`, *optional*):
|
171 |
+
If set will pad the sequence to a multiple of the provided value. Requires `padding` to be activated.
|
172 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
173 |
+
`>= 7.5` (Volta).
|
174 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
175 |
+
If set, will return tensors instead of list of python integers. Acceptable values are:
|
176 |
+
|
177 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
178 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
179 |
+
- `'np'`: Return Numpy `np.ndarray` objects.
|
180 |
+
"""
|
181 |
+
return self.tokenizer.encode(
|
182 |
+
text,
|
183 |
+
add_special_tokens=add_special_tokens,
|
184 |
+
padding=padding,
|
185 |
+
truncation=truncation,
|
186 |
+
max_length=max_length,
|
187 |
+
stride=stride,
|
188 |
+
return_tensors=return_tensors,
|
189 |
+
)
|
190 |
+
|
191 |
+
def decode(
|
192 |
+
self,
|
193 |
+
token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor"],
|
194 |
+
*, # Enforce keyword-only arguments
|
195 |
+
skip_special_tokens: bool = False,
|
196 |
+
clean_up_tokenization_spaces: bool = None,
|
197 |
+
**kwargs,
|
198 |
+
) -> str:
|
199 |
+
"""
|
200 |
+
Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
|
201 |
+
tokens and clean up tokenization spaces.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
|
205 |
+
List of tokenized input ids. Can be obtained using the `__call__` method.
|
206 |
+
skip_special_tokens (`bool`, *optional*, defaults to `False`):
|
207 |
+
Whether or not to remove special tokens in the decoding.
|
208 |
+
clean_up_tokenization_spaces (`bool`, *optional*):
|
209 |
+
Whether or not to clean up the tokenization spaces. If `None`, will default to
|
210 |
+
`self.clean_up_tokenization_spaces`.
|
211 |
+
kwargs (additional keyword arguments, *optional*):
|
212 |
+
Will be passed to the underlying model specific decode method.
|
213 |
+
|
214 |
+
Returns:
|
215 |
+
`str`: The decoded sentence.
|
216 |
+
"""
|
217 |
+
return self.tokenizer.decode(
|
218 |
+
token_ids,
|
219 |
+
skip_special_tokens=skip_special_tokens,
|
220 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
221 |
+
**kwargs,
|
222 |
+
)
|
223 |
+
|
224 |
+
def apply_chat_template(
|
225 |
+
self,
|
226 |
+
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
|
227 |
+
*,
|
228 |
+
add_generation_prompt: bool = False,
|
229 |
+
tokenize: bool = True,
|
230 |
+
padding: bool = False,
|
231 |
+
truncation: bool = False,
|
232 |
+
max_length: Optional[int] = None,
|
233 |
+
return_tensors: Optional[str] = None,
|
234 |
+
return_dict: bool = False,
|
235 |
+
return_assistant_tokens_mask: bool = False,
|
236 |
+
generation_prefix: str = "",
|
237 |
+
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
|
238 |
+
**kwargs,
|
239 |
+
):
|
240 |
+
"""
|
241 |
+
Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token
|
242 |
+
ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to determine the format and control tokens to use when converting.
|
243 |
+
|
244 |
+
More details can be found at https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template
|
245 |
+
|
246 |
+
Args:
|
247 |
+
conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts
|
248 |
+
with "role" and "content" keys, representing the chat history so far.
|
249 |
+
add_generation_prompt (bool, *optional*):
|
250 |
+
If this is set, a prompt with the token(s) that indicate
|
251 |
+
the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
|
252 |
+
Note that this argument will be passed to the chat template, and so it must be supported in the
|
253 |
+
template for this argument to have any effect.
|
254 |
+
continue_final_message (bool, *optional*):
|
255 |
+
If this is set, the chat will be formatted so that the final
|
256 |
+
message in the chat is open-ended, without any EOS tokens. The model will continue this message
|
257 |
+
rather than starting a new one. This allows you to "prefill" part of
|
258 |
+
the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
|
259 |
+
tokenize (`bool`, defaults to `True`):
|
260 |
+
Whether to tokenize the output. If `False`, the output will be a string.
|
261 |
+
padding (`bool`, defaults to `False`):
|
262 |
+
Whether to pad sequences to the maximum length. Has no effect if tokenize is `False`.
|
263 |
+
truncation (`bool`, defaults to `False`):
|
264 |
+
Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`.
|
265 |
+
max_length (`int`, *optional*):
|
266 |
+
Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If
|
267 |
+
not specified, the tokenizer's `max_length` attribute will be used as a default.
|
268 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
269 |
+
If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable
|
270 |
+
values are:
|
271 |
+
- `'tf'`: Return TensorFlow `tf.Tensor` objects.
|
272 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
273 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
274 |
+
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
275 |
+
return_dict (`bool`, defaults to `False`):
|
276 |
+
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
|
277 |
+
generation_prefix (str): Prefix to add before asking the model to generate. Helpful to guide the generation. Defaults to "".
|
278 |
+
tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer.
|
279 |
+
return_assistant_tokens_mask (`bool`, defaults to `False`):
|
280 |
+
Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
|
281 |
+
the mask will contain 1. For user and system tokens, the mask will contain 0.
|
282 |
+
This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
|
283 |
+
**kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
|
284 |
+
|
285 |
+
Returns:
|
286 |
+
`Union[List[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This
|
287 |
+
output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is
|
288 |
+
set, will return a dict of tokenizer outputs instead.
|
289 |
+
"""
|
290 |
+
if not self.is_instruct_model:
|
291 |
+
raise ValueError(
|
292 |
+
"apply_chat_template is only supported for instruct models. You should pass argument is_instruct_model=True to the TextTokenizer constructor."
|
293 |
+
)
|
294 |
+
# Since generation_prefix is appended to the text at the end, ensure that the settings are consistent
|
295 |
+
if generation_prefix:
|
296 |
+
assert not tokenize, "tokenize must be False when generation_prefix is provided."
|
297 |
+
assert add_generation_prompt, "add_generation_prompt must be set when generation_prefix is provided."
|
298 |
+
formatted_text: Union[str, List[int]] = self.tokenizer.apply_chat_template(
|
299 |
+
conversation,
|
300 |
+
add_generation_prompt=add_generation_prompt,
|
301 |
+
tokenize=tokenize,
|
302 |
+
padding=padding,
|
303 |
+
truncation=truncation,
|
304 |
+
max_length=max_length,
|
305 |
+
return_tensors=return_tensors,
|
306 |
+
return_dict=return_dict,
|
307 |
+
return_assistant_tokens_mask=return_assistant_tokens_mask,
|
308 |
+
tokenizer_kwargs=tokenizer_kwargs,
|
309 |
+
**kwargs,
|
310 |
+
)
|
311 |
+
if generation_prefix:
|
312 |
+
formatted_text: str = formatted_text + generation_prefix
|
313 |
+
log.debug(
|
314 |
+
f"Adding generation prefix: {generation_prefix} to the formatted text\n"
|
315 |
+
f"Formatted text: {formatted_text}"
|
316 |
+
)
|
317 |
+
return formatted_text
|
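As a usage note for the method above, here is a minimal sketch of calling apply_chat_template with a generation_prefix. It assumes a TextTokenizer instance named `tokenizer` constructed with is_instruct_model=True; the conversation content is purely illustrative.

# Minimal sketch, assuming `tokenizer` is a TextTokenizer built with is_instruct_model=True.
conversation = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Describe the scene in one sentence."},
]

# With a generation_prefix, tokenize must be False and add_generation_prompt True,
# so the result is a formatted string that ends with the prefix.
formatted = tokenizer.apply_chat_template(
    conversation,
    tokenize=False,
    add_generation_prompt=True,
    generation_prefix=" The scene shows",
)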
ar_tokenizer_utils.py
ADDED
@@ -0,0 +1,101 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Any
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from einops import pack, rearrange, unpack
|
20 |
+
|
21 |
+
|
22 |
+
def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
|
23 |
+
batch_size = x.shape[0]
|
24 |
+
return rearrange(x, "b c t h w -> (b t) c h w"), batch_size
|
25 |
+
|
26 |
+
|
27 |
+
def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor:
|
28 |
+
return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
|
29 |
+
|
30 |
+
|
31 |
+
def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
|
32 |
+
batch_size, height = x.shape[0], x.shape[-2]
|
33 |
+
return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height
|
34 |
+
|
35 |
+
|
36 |
+
def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor:
|
37 |
+
return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height)
|
38 |
+
|
39 |
+
|
40 |
+
def cast_tuple(t: Any, length: int = 1) -> Any:
|
41 |
+
return t if isinstance(t, tuple) else ((t,) * length)
|
42 |
+
|
43 |
+
|
44 |
+
def replication_pad(x):
|
45 |
+
return torch.cat([x[:, :, :1, ...], x], dim=2)
|
46 |
+
|
47 |
+
|
48 |
+
def divisible_by(num: int, den: int) -> bool:
|
49 |
+
return (num % den) == 0
|
50 |
+
|
51 |
+
|
52 |
+
def is_odd(n: int) -> bool:
|
53 |
+
return not divisible_by(n, 2)
|
54 |
+
|
55 |
+
|
56 |
+
def nonlinearity(x):
|
57 |
+
return x * torch.sigmoid(x)
|
58 |
+
|
59 |
+
|
60 |
+
class CausalNormalize(torch.nn.Module):
|
61 |
+
def __init__(self, in_channels, num_groups=1):
|
62 |
+
super().__init__()
|
63 |
+
self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
64 |
+
self.num_groups = num_groups
|
65 |
+
|
66 |
+
def forward(self, x):
|
67 |
+
# If num_groups != 1, we apply a spatio-temporal group norm for backward-compatibility purposes.
|
68 |
+
# All new models should use num_groups=1, otherwise causality is not guaranteed.
|
69 |
+
if self.num_groups == 1:
|
70 |
+
x, batch_size = time2batch(x)
|
71 |
+
return batch2time(self.norm(x), batch_size)
|
72 |
+
return self.norm(x)
|
73 |
+
|
74 |
+
|
75 |
+
def exists(v):
|
76 |
+
return v is not None
|
77 |
+
|
78 |
+
|
79 |
+
def default(*args):
|
80 |
+
for arg in args:
|
81 |
+
if exists(arg):
|
82 |
+
return arg
|
83 |
+
return None
|
84 |
+
|
85 |
+
|
86 |
+
def pack_one(t, pattern):
|
87 |
+
return pack([t], pattern)
|
88 |
+
|
89 |
+
|
90 |
+
def unpack_one(t, ps, pattern):
|
91 |
+
return unpack(t, ps, pattern)[0]
|
92 |
+
|
93 |
+
|
94 |
+
def round_ste(z: torch.Tensor) -> torch.Tensor:
|
95 |
+
"""Round with straight through gradients."""
|
96 |
+
zhat = z.round()
|
97 |
+
return z + (zhat - z).detach()
|
98 |
+
|
99 |
+
|
100 |
+
def log(t, eps=1e-5):
|
101 |
+
return t.clamp(min=eps).log()
|
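Before the next file, a small runnable sketch (not part of the diff) of how the helpers above behave: time2batch/batch2time fold the time axis into the batch so a 2D operation, such as the GroupNorm inside CausalNormalize, is applied per frame, and round_ste rounds in the forward pass while passing gradients straight through.

# Illustrative shapes only.
import torch
from einops import rearrange

x = torch.randn(2, 8, 5, 16, 16)               # (b, c, t, h, w)
flat = rearrange(x, "b c t h w -> (b t) c h w")
assert flat.shape == (2 * 5, 8, 16, 16)        # time folded into batch, as in time2batch
back = rearrange(flat, "(b t) c h w -> b c t h w", b=2)
assert torch.equal(back, x)                    # batch2time inverts it exactly

z = torch.randn(4, requires_grad=True)
zhat = z + (z.round() - z).detach()            # round_ste: rounded forward, identity backward
zhat.sum().backward()
assert torch.allclose(z.grad, torch.ones_like(z))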
ar_transformer.py
ADDED
@@ -0,0 +1,461 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Any, Dict, Optional
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
from torch.nn.modules.module import _IncompatibleKeys
|
21 |
+
|
22 |
+
from .ar_modules_attention import Attention
|
23 |
+
from .ar_modules_embedding import (
|
24 |
+
RotaryPositionEmbeddingPytorchV1,
|
25 |
+
RotaryPositionEmbeddingPytorchV2,
|
26 |
+
SinCosPosEmbAxisTE,
|
27 |
+
)
|
28 |
+
from .ar_modules_mlp import MLP
|
29 |
+
from .ar_modules_normalization import create_norm
|
30 |
+
from .checkpoint import process_state_dict, substrings_to_ignore
|
31 |
+
from .ar_utils_misc import maybe_convert_to_namespace
|
32 |
+
from .log import log
|
33 |
+
|
34 |
+
|
35 |
+
class TransformerBlock(nn.Module):
|
36 |
+
"""
|
37 |
+
A single transformer block consisting of an attention layer and a feed-forward layer.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self, layer_id: int, args=None):
|
41 |
+
"""
|
42 |
+
Initializes the TransformerBlock module.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
layer_id: The ID of the transformer block.
|
46 |
+
args: The model arguments containing hyperparameters.
|
47 |
+
"""
|
48 |
+
super().__init__()
|
49 |
+
args = maybe_convert_to_namespace(args)
|
50 |
+
attention_args = {
|
51 |
+
"n_heads": args["n_heads"],
|
52 |
+
"n_kv_heads": args["n_kv_heads"],
|
53 |
+
"dim": args["dim"],
|
54 |
+
"context_dim": None,
|
55 |
+
"max_batch_size": args["max_batch_size"],
|
56 |
+
"max_seq_len": args["max_seq_len"],
|
57 |
+
"use_qk_normalization": args["use_qk_normalization"],
|
58 |
+
"causal_mask": args["causal_mask"],
|
59 |
+
"head_dim": args["head_dim"],
|
60 |
+
"fuse_qkv": getattr(args, "fuse_qkv", False),
|
61 |
+
"precision": getattr(args, "precision", "bfloat16"),
|
62 |
+
"attn_type": getattr(args, "attn_type", "self"),
|
63 |
+
}
|
64 |
+
self.attention = Attention(**attention_args)
|
65 |
+
|
66 |
+
self.has_cross_attention = False
|
67 |
+
self.cross_attention, self.cross_attention_norm = None, None
|
68 |
+
|
69 |
+
if args["insert_cross_attn"] and layer_id % args["insert_cross_attn_every_k_layers"] == 0:
|
70 |
+
self.has_cross_attention = True
|
71 |
+
cross_attention_args = attention_args.copy()
|
72 |
+
cross_attention_args.update({"context_dim": args["context_dim"], "fuse_qkv": False, "attn_type": "cross"})
|
73 |
+
self.cross_attention = Attention(**cross_attention_args)
|
74 |
+
self.cross_attention_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"])
|
75 |
+
|
76 |
+
self.feed_forward = MLP(
|
77 |
+
dim=args["dim"],
|
78 |
+
hidden_dim=args["ffn_hidden_size"],
|
79 |
+
)
|
80 |
+
self.layer_id = layer_id
|
81 |
+
self.attention_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"])
|
82 |
+
self.ffn_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"])
|
83 |
+
|
84 |
+
def forward(
|
85 |
+
self,
|
86 |
+
x: torch.Tensor,
|
87 |
+
rope: RotaryPositionEmbeddingPytorchV2,
|
88 |
+
input_pos: Optional[torch.Tensor] = None,
|
89 |
+
mask: Optional[torch.Tensor] = None,
|
90 |
+
context: Optional[torch.Tensor] = None,
|
91 |
+
context_mask: Optional[torch.Tensor] = None,
|
92 |
+
) -> torch.Tensor:
|
93 |
+
"""
|
94 |
+
Performs the forward pass of the TransformerBlock module.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
x: The input tensor.
|
98 |
+
input_pos: The position of the current sequence. Used in inference (with KV cache) only.
|
99 |
+
rope: The rotary position embedding module providing the precomputed frequencies.
|
100 |
+
mask: The attention mask tensor.
|
101 |
+
context (Optional[torch.Tensor]): The context tensor added via cross-attn.
|
102 |
+
context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor.
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
The output tensor after applying the transformer block.
|
106 |
+
"""
|
107 |
+
# Apply attention and residual connection
|
108 |
+
h = x + self.attention(self.attention_norm(x), rope=rope, input_pos=input_pos, mask=mask)
|
109 |
+
|
110 |
+
# If insert cross-attention, apply CA and residual connection
|
111 |
+
if self.has_cross_attention:
|
112 |
+
h = h + self.cross_attention(
|
113 |
+
self.cross_attention_norm(h), rope=rope, input_pos=input_pos, mask=context_mask, context=context
|
114 |
+
)
|
115 |
+
|
116 |
+
# Apply feed-forward network and residual connection
|
117 |
+
out = h + self.feed_forward(self.ffn_norm(h))
|
118 |
+
return out
|
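For reference, a runnable sketch of the pre-norm residual wiring used in TransformerBlock.forward above, with nn.LayerNorm and nn.Linear as stand-ins for the real norm, attention, and MLP modules.

import torch
import torch.nn as nn

dim = 16
attn_norm, ffn_norm = nn.LayerNorm(dim), nn.LayerNorm(dim)
attn, ffn = nn.Linear(dim, dim), nn.Linear(dim, dim)   # stand-ins for Attention / MLP

x = torch.randn(2, 5, dim)
h = x + attn(attn_norm(x))      # attention branch with residual connection
out = h + ffn(ffn_norm(h))      # feed-forward branch with residual connection
assert out.shape == x.shape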
119 |
+
|
120 |
+
def init_weights(self):
|
121 |
+
"""
|
122 |
+
Initializes the weights of the transformer block.
|
123 |
+
"""
|
124 |
+
for norm in (self.attention_norm, self.ffn_norm):
|
125 |
+
norm.reset_parameters()
|
126 |
+
self.attention.init_weights(self.weight_init_std)
|
127 |
+
self.feed_forward.init_weights(self.weight_init_std)
|
128 |
+
|
129 |
+
if self.has_cross_attention:
|
130 |
+
self.cross_attention_norm.reset_parameters()
|
131 |
+
self.cross_attention.init_weights(self.weight_init_std)
|
132 |
+
# zero-init the final output layer of cross-attention
|
133 |
+
# nn.init.zeros_(self.cross_attention.wo.weight)
|
134 |
+
|
135 |
+
|
136 |
+
class Transformer(nn.Module):
|
137 |
+
"""
|
138 |
+
The Transformer network consisting of transformer blocks.
|
139 |
+
"""
|
140 |
+
|
141 |
+
def __init__(self, params, tokenizer_config=None, init_weights: bool = True):
|
142 |
+
"""
|
143 |
+
Initializes the Transformer module.
|
144 |
+
|
145 |
+
Args:
|
146 |
+
params: The model parameters containing hyperparameters.
|
147 |
+
tokenizer_config: The model tokenizer configuration.
|
148 |
+
init_weights (bool): Whether to initialize the weights of the transformer following
|
149 |
+
TorchTitan's Llama3 initialization scheme.
|
150 |
+
"""
|
151 |
+
super().__init__()
|
152 |
+
# Check if self.params is an OmegaConf DictConfig instance
|
153 |
+
self.params = maybe_convert_to_namespace(params)
|
154 |
+
self.vocab_size = params["vocab_size"]
|
155 |
+
self.n_layers = params["n_layers"]
|
156 |
+
self.precision = getattr(torch, params["precision"])
|
157 |
+
self.tokenizer_config = tokenizer_config
|
158 |
+
self.num_video_frames = params["num_video_frames"]
|
159 |
+
|
160 |
+
# Token embeddings
|
161 |
+
self.tok_embeddings = self._create_token_embeddings()
|
162 |
+
self.rope_config = self._create_rope_config()
|
163 |
+
|
164 |
+
# Transformer layers
|
165 |
+
self.layers = nn.ModuleList(
|
166 |
+
[TransformerBlock(layer_id, self.params).to(self.precision) for layer_id in range(self.n_layers)]
|
167 |
+
)
|
168 |
+
|
169 |
+
# Final layer normalization
|
170 |
+
self.norm = create_norm(self.params["norm_type"], dim=self.params["dim"], eps=self.params["norm_eps"]).to(
|
171 |
+
self.precision
|
172 |
+
)
|
173 |
+
if self.params["pytorch_rope_version"] == "v1":
|
174 |
+
self.rope = RotaryPositionEmbeddingPytorchV1(**self.rope_config)
|
175 |
+
elif self.params["pytorch_rope_version"] == "v2":
|
176 |
+
# Rotary position embeddings
|
177 |
+
training_type = self.tokenizer_config.training_type if self.tokenizer_config is not None else None
|
178 |
+
self.rope = RotaryPositionEmbeddingPytorchV2(
|
179 |
+
seq_len=self.params["max_seq_len"], training_type=training_type, **self.rope_config
|
180 |
+
)
|
181 |
+
else:
|
182 |
+
raise ValueError(f"Invalid PyTorch RoPE version: {self.params['pytorch_rope_version']}")
|
183 |
+
# Causal mask
|
184 |
+
self.causal_mask = torch.tril(
|
185 |
+
torch.ones(self.params["max_seq_len"], self.params["max_seq_len"], dtype=torch.bool)
|
186 |
+
).cuda()
|
187 |
+
|
188 |
+
# Output projection
|
189 |
+
self.output = self._create_output_projection()
|
190 |
+
|
191 |
+
# Freeze network parameters for finetuning w/ cross-attention
|
192 |
+
self.has_cross_attention = getattr(params, "insert_cross_attn", False)
|
193 |
+
|
194 |
+
# Absolute position embeddings
|
195 |
+
if self.params["apply_abs_pos_emb"]:
|
196 |
+
self.pos_emb_config = self._create_abs_pos_emb_config()
|
197 |
+
self.pos_emb, self.abs_pos_emb = self._initialize_abs_pos_emb()
|
198 |
+
|
199 |
+
def _create_rope_config(self) -> Dict:
|
200 |
+
shape_map = {
|
201 |
+
"3D": self.params["video_latent_shape"],
|
202 |
+
"1D": None,
|
203 |
+
}
|
204 |
+
latent_shape = shape_map.get(self.params["rope_dim"], None)
|
205 |
+
head_dim = self.params["head_dim"]
|
206 |
+
if head_dim is None:
|
207 |
+
head_dim = self.params["dim"] // self.params["n_heads"]
|
208 |
+
return {
|
209 |
+
"dim": head_dim,
|
210 |
+
"max_position_embeddings": self.params["max_seq_len"],
|
211 |
+
"original_max_position_embeddings": self.params["original_seq_len"],
|
212 |
+
"rope_theta": self.params["rope_theta"],
|
213 |
+
"apply_yarn": self.params["apply_yarn"],
|
214 |
+
"scale": self.params["yarn_scale"],
|
215 |
+
"beta_fast": self.params["yarn_beta_fast"],
|
216 |
+
"beta_slow": self.params["yarn_beta_slow"],
|
217 |
+
"rope_dim": self.params["rope_dim"],
|
218 |
+
"latent_shape": latent_shape,
|
219 |
+
"original_latent_shape": self.params["original_latent_shape"],
|
220 |
+
"pad_to_multiple_of": self.params["pad_to_multiple_of"],
|
221 |
+
}
|
222 |
+
|
223 |
+
def _create_abs_pos_emb_config(self):
|
224 |
+
shape_map = {
|
225 |
+
"3D": self.params["video_latent_shape"],
|
226 |
+
"1D": None,
|
227 |
+
}
|
228 |
+
latent_shape = shape_map.get(self.params["rope_dim"], None)
|
229 |
+
return {
|
230 |
+
"dim": self.params["dim"],
|
231 |
+
"latent_shape": latent_shape,
|
232 |
+
"pad_to_multiple_of": self.params["pad_to_multiple_of"],
|
233 |
+
}
|
234 |
+
|
235 |
+
def _create_token_embeddings(self, vocab_size: int = None):
|
236 |
+
"""
|
237 |
+
Create token embeddings.
|
238 |
+
|
239 |
+
Returns:
|
240 |
+
nn.Module: Token embeddings module.
|
241 |
+
"""
|
242 |
+
if vocab_size is None:
|
243 |
+
vocab_size = self.params["vocab_size"]
|
244 |
+
return nn.Embedding(vocab_size, self.params["dim"]).to(self.precision)
|
245 |
+
|
246 |
+
def _create_output_projection(self, vocab_size: int = None):
|
247 |
+
"""
|
248 |
+
Create the output projection layer.
|
249 |
+
|
250 |
+
Args:
|
251 |
+
vocab_size (int): Vocabulary size (to override the default vocab size).
|
252 |
+
Returns:
|
253 |
+
LinearTE: Output projection layer.
|
254 |
+
"""
|
255 |
+
if vocab_size is None:
|
256 |
+
vocab_size = self.params["vocab_size"]
|
257 |
+
return nn.Linear(self.params["dim"], vocab_size, bias=False).to(self.precision)
|
258 |
+
|
259 |
+
def _initialize_abs_pos_emb(self):
|
260 |
+
pos_emb = SinCosPosEmbAxisTE(**self.pos_emb_config)
|
261 |
+
training_type = self.tokenizer_config.training_type if self.tokenizer_config is not None else None
|
262 |
+
abs_pos_emb = pos_emb.forward(training_type=training_type)
|
263 |
+
return pos_emb, abs_pos_emb
|
264 |
+
|
265 |
+
def forward(
|
266 |
+
self,
|
267 |
+
tokens: Optional[torch.Tensor] = None,
|
268 |
+
input_pos: Optional[torch.Tensor] = None,
|
269 |
+
token_embeddings: Optional[torch.Tensor] = None,
|
270 |
+
context: Optional[torch.Tensor] = None,
|
271 |
+
context_mask: Optional[torch.Tensor] = None,
|
272 |
+
) -> torch.Tensor:
|
273 |
+
"""
|
274 |
+
Performs the forward pass of the Transformer module.
|
275 |
+
|
276 |
+
Args:
|
277 |
+
tokens (torch.Tensor, optional): The input tensor of token IDs.
|
278 |
+
input_pos (Optional[torch.Tensor]): The position of the current sequence. Used in inference with KV cache.
|
279 |
+
token_embeddings (torch.Tensor, optional): Precomputed token embeddings. If provided, tokens should be None.
|
280 |
+
context (Optional[torch.Tensor]): The context tensor added via cross-attn.
|
281 |
+
context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor.
|
282 |
+
Returns:
|
283 |
+
The output tensor after applying the transformer layers.
|
284 |
+
"""
|
285 |
+
# Token embeddings
|
286 |
+
assert (
|
287 |
+
tokens is None or token_embeddings is None
|
288 |
+
), "Either tokens or token_embeddings should be provided, not both."
|
289 |
+
|
290 |
+
if token_embeddings is None:
|
291 |
+
seq_len = tokens.shape[1]
|
292 |
+
h = self.tok_embeddings(tokens)
|
293 |
+
else:
|
294 |
+
seq_len = token_embeddings.shape[1]
|
295 |
+
h = token_embeddings
|
296 |
+
|
297 |
+
# Create attention mask
|
298 |
+
mask = self._create_attention_mask(input_pos=input_pos)
|
299 |
+
|
300 |
+
# Prepare layer arguments
|
301 |
+
layer_kwargs = self._prepare_layer_kwargs(
|
302 |
+
input_pos=input_pos,
|
303 |
+
mask=mask,
|
304 |
+
context=context,
|
305 |
+
context_mask=context_mask,
|
306 |
+
)
|
307 |
+
|
308 |
+
# Apply transformer layers
|
309 |
+
for layer in self.layers:
|
310 |
+
if self.params["apply_abs_pos_emb"]:
|
311 |
+
h = self.apply_abs_pos_emb(h, input_pos=input_pos)
|
312 |
+
h = layer(h, **layer_kwargs)
|
313 |
+
|
314 |
+
# Apply final layer normalization
|
315 |
+
h = self.norm(h)
|
316 |
+
|
317 |
+
# Output linear projection
|
318 |
+
output = self.output(h)
|
319 |
+
return output
|
320 |
+
|
321 |
+
def _create_attention_mask(self, input_pos: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
322 |
+
"""
|
323 |
+
Creates an attention mask for the transformer layers.
|
324 |
+
|
325 |
+
Args:
|
326 |
+
input_pos (Optional[torch.Tensor]): The position of the input sequence (used for inference only).
|
327 |
+
|
328 |
+
Returns:
|
329 |
+
Optional[torch.Tensor]: The attention mask, or None for causal mask.
|
330 |
+
"""
|
331 |
+
|
332 |
+
assert input_pos is not None, "input_pos must be provided for inference"
|
333 |
+
mask = self.causal_mask[input_pos]
|
334 |
+
return mask
|
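A small sketch of the causal-mask slicing performed above: indexing the precomputed lower-triangular mask with input_pos selects the mask rows for the positions currently being decoded.

import torch

max_seq_len = 8
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
input_pos = torch.tensor([3])              # decoding step 3 with a KV cache
mask = causal_mask[input_pos]              # shape (1, 8): only positions 0..3 are visible
assert mask.sum().item() == 4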
335 |
+
|
336 |
+
def _prepare_layer_kwargs(
|
337 |
+
self,
|
338 |
+
input_pos: Optional[torch.Tensor],
|
339 |
+
mask: Optional[torch.Tensor],
|
340 |
+
context: Optional[torch.Tensor],
|
341 |
+
context_mask: Optional[torch.Tensor],
|
342 |
+
) -> Dict[str, Any]:
|
343 |
+
"""
|
344 |
+
Prepares the keyword arguments for transformer layers.
|
345 |
+
|
346 |
+
Args:
|
347 |
+
input_pos (Optional[torch.Tensor]): The position of the current sequence.
|
348 |
+
mask (Optional[torch.Tensor]): The attention mask.
|
349 |
+
context (Optional[torch.Tensor]): The context tensor added via cross-attn.
|
350 |
+
context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor.
|
351 |
+
|
352 |
+
Returns:
|
353 |
+
Dict[str, Any]: A dictionary of keyword arguments for the transformer layers.
|
354 |
+
"""
|
355 |
+
if context is not None:
|
356 |
+
context = context.to(self.precision)
|
357 |
+
|
358 |
+
if isinstance(mask, torch.Tensor) and mask.ndim == 2:
|
359 |
+
mask = mask[None, None, :, :]
|
360 |
+
if isinstance(context_mask, torch.Tensor) and context_mask.ndim == 2:
|
361 |
+
context_mask = context_mask[None, None, :, :]
|
362 |
+
|
363 |
+
layer_kwargs = {
|
364 |
+
"mask": mask,
|
365 |
+
"context": context,
|
366 |
+
"context_mask": context_mask,
|
367 |
+
}
|
368 |
+
|
369 |
+
layer_kwargs["input_pos"] = input_pos
|
370 |
+
layer_kwargs["rope"] = self.rope
|
371 |
+
|
372 |
+
return layer_kwargs
|
373 |
+
|
374 |
+
def apply_abs_pos_emb(self, x: torch.Tensor, input_pos: int = None) -> torch.Tensor:
|
375 |
+
"""
|
376 |
+
Applies the absolute position embeddings to the input tensor.
|
377 |
+
"""
|
378 |
+
abs_pos_emb = self.abs_pos_emb
|
379 |
+
abs_pos_emb = abs_pos_emb[:, input_pos, :] if input_pos is not None else abs_pos_emb
|
380 |
+
return x + abs_pos_emb
|
381 |
+
|
382 |
+
@torch.no_grad()
|
383 |
+
def expand_vocab(
|
384 |
+
self, new_vocab_size: int, init_method: str = "gaussian", multiple_of=64, expand_output_layer=True
|
385 |
+
):
|
386 |
+
"""
|
387 |
+
Expands the vocabulary of the model to the new size.
|
388 |
+
|
389 |
+
Args:
|
390 |
+
new_vocab_size (int): The new vocabulary size.
|
391 |
+
init_method (str): The initialization method for new embeddings.
|
392 |
+
Can be "zero" or "gaussian". Default is "gaussian".
|
393 |
+
multiple_of (int): The new vocabulary size must be a multiple of this value. Defaults to 64 to fully
|
394 |
+
leverage the power of NVIDIA TensorCore (source 1: https://x.com/karpathy/status/1621578354024677377,
|
395 |
+
source 2: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc)
|
396 |
+
expand_output_layer (bool): Whether to also expand the output layer. Defaults to True.
|
397 |
+
|
398 |
+
Returns:
|
399 |
+
None
|
400 |
+
"""
|
401 |
+
if new_vocab_size <= self.vocab_size:
|
402 |
+
raise ValueError(
|
403 |
+
f"New vocabulary size ({new_vocab_size}) must be " f"larger than current size ({self.vocab_size})"
|
404 |
+
)
|
405 |
+
if new_vocab_size % multiple_of != 0:
|
406 |
+
log.debug(f"New vocabulary size must be a multiple of {multiple_of}. Obtained {new_vocab_size}.")
|
407 |
+
new_vocab_size = (new_vocab_size // multiple_of + 1) * multiple_of
|
408 |
+
log.debug(f"Rounded vocabulary size to {new_vocab_size}.")
|
409 |
+
# Resize token embeddings
|
410 |
+
old_embeddings = self.tok_embeddings
|
411 |
+
tensor_kwargs = {"device": old_embeddings.weight.device, "dtype": old_embeddings.weight.dtype}
|
412 |
+
self.tok_embeddings = self._create_token_embeddings(vocab_size=new_vocab_size).to(**tensor_kwargs)
|
413 |
+
# Initialize new embeddings
|
414 |
+
if init_method not in ["zero", "gaussian"]:
|
415 |
+
raise ValueError(f"Unknown initialization method: {init_method}")
|
416 |
+
# The default initialization of nn.Embedding is Gaussian, so we don't need to do anything
|
417 |
+
# if init_method == "gaussian". Only if init_method == "zero", we need to zero out the new embeddings.
|
418 |
+
if init_method == "zero":
|
419 |
+
self.tok_embeddings.weight.data[self.vocab_size :].zero_()
|
420 |
+
|
421 |
+
# Copy old embeddings
|
422 |
+
log.debug(
|
423 |
+
f"old_embeddings: {old_embeddings.weight.data.shape}, new_embeddings: {self.tok_embeddings.weight.data.shape}, vocab_size: {self.vocab_size}"
|
424 |
+
)
|
425 |
+
self.tok_embeddings.weight.data[: self.vocab_size] = old_embeddings.weight.data
|
426 |
+
# Resize output layer
|
427 |
+
old_output = self.output
|
428 |
+
self.output = self._create_output_projection(vocab_size=new_vocab_size if expand_output_layer else None)
|
429 |
+
|
430 |
+
# Initialize new output weights
|
431 |
+
self.output.weight.data[self.vocab_size :].zero_()
|
432 |
+
# Copy old output weights
|
433 |
+
self.output.weight.data[: self.vocab_size] = old_output.weight.data
|
434 |
+
|
435 |
+
# Update vocab size
|
436 |
+
self.vocab_size = new_vocab_size
|
437 |
+
log.debug(f"Expanded vocabulary size to {new_vocab_size}")
|
438 |
+
|
439 |
+
def state_dict(self, *args, **kwargs):
|
440 |
+
"""
|
441 |
+
Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8).
|
442 |
+
"""
|
443 |
+
state_dict = super().state_dict(*args, **kwargs)
|
444 |
+
return process_state_dict(state_dict)
|
445 |
+
|
446 |
+
def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False):
|
447 |
+
"""
|
448 |
+
Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by
|
449 |
+
TransformerEngine for FP8).
|
450 |
+
"""
|
451 |
+
state_dict = process_state_dict(state_dict)
|
452 |
+
missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign)
|
453 |
+
if strict:
|
454 |
+
actual_missing_keys = []
|
455 |
+
for key in missing_keys:
|
456 |
+
if not any(substring in key for substring in substrings_to_ignore):
|
457 |
+
actual_missing_keys.append(key)
|
458 |
+
if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0:
|
459 |
+
raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}")
|
460 |
+
missing_keys = actual_missing_keys
|
461 |
+
return _IncompatibleKeys(missing_keys, unexpected_keys)
|
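As a side note on expand_vocab above, here is a small sketch of the vocabulary rounding and row-copy bookkeeping; the sizes are made up for illustration.

import torch.nn as nn

old_vocab, dim, multiple_of = 64000, 4096, 64
requested = 64100
new_vocab = requested if requested % multiple_of == 0 else (requested // multiple_of + 1) * multiple_of
assert new_vocab == 64128                        # rounded up to the next multiple of 64

old_emb = nn.Embedding(old_vocab, dim)
new_emb = nn.Embedding(new_vocab, dim)           # Gaussian init by default
new_emb.weight.data[:old_vocab] = old_emb.weight.data  # preserve the learned rows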
ar_utils_misc.py
ADDED
@@ -0,0 +1,52 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from omegaconf import DictConfig, OmegaConf
|
17 |
+
|
18 |
+
|
19 |
+
class CustomSimpleNamespace:
|
20 |
+
"""
|
21 |
+
A simple namespace class that supports both attribute-style and dictionary-style access.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, d):
|
25 |
+
self._d = d
|
26 |
+
|
27 |
+
def __getattr__(self, attr):
|
28 |
+
# Attribute-style access: config.key
|
29 |
+
try:
|
30 |
+
return self._d[attr]
|
31 |
+
except KeyError:
|
32 |
+
raise AttributeError(f"'CustomSimpleNamespace' object has no attribute '{attr}'")
|
33 |
+
|
34 |
+
def __getitem__(self, key):
|
35 |
+
# Dictionary-style access: config['key']
|
36 |
+
return self._d[key]
|
37 |
+
|
38 |
+
|
39 |
+
def maybe_convert_to_namespace(config):
|
40 |
+
"""
|
41 |
+
This function casts an OmegaConf DictConfig or a standard dict to a CustomSimpleNamespace, which supports both
|
42 |
+
attribute-style and dictionary-style access.
|
43 |
+
Note: We need to convert OmegaConf's DictConfig since it is not compatible with torch.compile.
|
44 |
+
"""
|
45 |
+
# If input is OmegaConf's DictConfig, convert to a standard dict
|
46 |
+
if isinstance(config, DictConfig):
|
47 |
+
config = OmegaConf.to_container(config, resolve=True)
|
48 |
+
|
49 |
+
if isinstance(config, dict):
|
50 |
+
return CustomSimpleNamespace(config)
|
51 |
+
else:
|
52 |
+
return config
|
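A short usage sketch of maybe_convert_to_namespace above; the config keys are illustrative only.

# Both access styles work after conversion; non-dict objects pass through unchanged.
cfg = maybe_convert_to_namespace({"dim": 4096, "n_heads": 32})
assert cfg.dim == 4096 and cfg["n_heads"] == 32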
attention.py
ADDED
@@ -0,0 +1,305 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import List, Optional
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
import transformer_engine as te
|
21 |
+
from einops import rearrange
|
22 |
+
from torch import nn
|
23 |
+
from torch.utils.checkpoint import checkpoint
|
24 |
+
from transformer_engine.pytorch.attention import DotProductAttention, apply_rotary_pos_emb
|
25 |
+
|
26 |
+
# ---------------------- Feed Forward Network -----------------------
|
27 |
+
|
28 |
+
|
29 |
+
class FeedForward(nn.Module):
|
30 |
+
"""
|
31 |
+
Transformer FFN with optional gating
|
32 |
+
|
33 |
+
Parameters:
|
34 |
+
d_model (int): Dimensionality of input features.
|
35 |
+
d_ff (int): Dimensionality of the hidden layer.
|
36 |
+
dropout (float, optional): Dropout rate applied after the activation function. Defaults to 0.1.
|
37 |
+
activation (callable, optional): The activation function applied after the first linear layer.
|
38 |
+
Defaults to nn.ReLU().
|
39 |
+
is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer.
|
40 |
+
Defaults to False.
|
41 |
+
bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to True.
|
42 |
+
|
43 |
+
Example:
|
44 |
+
>>> ff = FeedForward(d_model=512, d_ff=2048)
|
45 |
+
>>> x = torch.randn(64, 10, 512) # Example input tensor
|
46 |
+
>>> output = ff(x)
|
47 |
+
>>> print(output.shape) # Expected shape: (64, 10, 512)
|
48 |
+
"""
|
49 |
+
|
50 |
+
def __init__(
|
51 |
+
self,
|
52 |
+
d_model: int,
|
53 |
+
d_ff: int,
|
54 |
+
dropout: float = 0.1,
|
55 |
+
activation=nn.ReLU(),
|
56 |
+
is_gated: bool = False,
|
57 |
+
bias: bool = False,
|
58 |
+
) -> None:
|
59 |
+
super().__init__()
|
60 |
+
|
61 |
+
self.layer1 = nn.Linear(d_model, d_ff, bias=bias)
|
62 |
+
self.layer2 = nn.Linear(d_ff, d_model, bias=bias)
|
63 |
+
|
64 |
+
self.dropout = nn.Dropout(dropout)
|
65 |
+
self.activation = activation
|
66 |
+
self.is_gated = is_gated
|
67 |
+
if is_gated:
|
68 |
+
self.linear_gate = nn.Linear(d_model, d_ff, bias=False)
|
69 |
+
|
70 |
+
def forward(self, x: torch.Tensor):
|
71 |
+
g = self.activation(self.layer1(x))
|
72 |
+
if self.is_gated:
|
73 |
+
x = g * self.linear_gate(x)
|
74 |
+
else:
|
75 |
+
x = g
|
76 |
+
assert self.dropout.p == 0.0, "we skip dropout"
|
77 |
+
return self.layer2(x)
|
78 |
+
|
79 |
+
|
80 |
+
class GPT2FeedForward(FeedForward):
|
81 |
+
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False):
|
82 |
+
super().__init__(
|
83 |
+
d_model=d_model,
|
84 |
+
d_ff=d_ff,
|
85 |
+
dropout=dropout,
|
86 |
+
activation=nn.GELU(),
|
87 |
+
is_gated=False,
|
88 |
+
bias=bias,
|
89 |
+
)
|
90 |
+
|
91 |
+
def forward(self, x: torch.Tensor):
|
92 |
+
assert self.dropout.p == 0.0, "we skip dropout"
|
93 |
+
|
94 |
+
x = self.layer1(x)
|
95 |
+
|
96 |
+
def activation_layer2_forward(x):
|
97 |
+
x = self.activation(x)
|
98 |
+
x = self.layer2(x)
|
99 |
+
return x
|
100 |
+
|
101 |
+
x = checkpoint(activation_layer2_forward, x, use_reentrant=False)
|
102 |
+
return x
|
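A self-contained sketch of the activation-checkpointing trick used in GPT2FeedForward.forward above: the activation and second projection are recomputed during the backward pass instead of keeping their intermediate activations in memory. Module sizes here are arbitrary.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer1, layer2, act = nn.Linear(32, 128), nn.Linear(128, 32), nn.GELU()
x = layer1(torch.randn(4, 32, requires_grad=True))

def act_then_project(t):
    # This segment is re-executed in the backward pass instead of caching activations.
    return layer2(act(t))

y = checkpoint(act_then_project, x, use_reentrant=False)
y.sum().backward()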
103 |
+
|
104 |
+
|
105 |
+
# ---------------------- Normalization Layer -----------------------
|
106 |
+
|
107 |
+
|
108 |
+
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
|
109 |
+
"""
|
110 |
+
Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
x (torch.Tensor): The input tensor to normalize.
|
114 |
+
dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first.
|
115 |
+
eps (float, optional): A small constant to ensure numerical stability during division.
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
torch.Tensor: The normalized tensor.
|
119 |
+
"""
|
120 |
+
if dim is None:
|
121 |
+
dim = list(range(1, x.ndim))
|
122 |
+
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
|
123 |
+
norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
|
124 |
+
return x / norm.to(x.dtype)
|
125 |
+
|
126 |
+
|
127 |
+
def get_normalization(name: str, channels: int):
|
128 |
+
if name == "I":
|
129 |
+
return nn.Identity()
|
130 |
+
elif name == "R":
|
131 |
+
return te.pytorch.RMSNorm(channels, eps=1e-6)
|
132 |
+
else:
|
133 |
+
raise ValueError(f"Normalization {name} not found")
|
134 |
+
|
135 |
+
|
136 |
+
class BaseAttentionOp(nn.Module):
|
137 |
+
def __init__(self):
|
138 |
+
super().__init__()
|
139 |
+
|
140 |
+
|
141 |
+
class Attention(nn.Module):
|
142 |
+
"""
|
143 |
+
Generalized attention implementation.
|
144 |
+
|
145 |
+
Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided.
|
146 |
+
If `context_dim` is None, self-attention is assumed.
|
147 |
+
|
148 |
+
Parameters:
|
149 |
+
query_dim (int): Dimension of each query vector.
|
150 |
+
context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed.
|
151 |
+
heads (int, optional): Number of attention heads. Defaults to 8.
|
152 |
+
dim_head (int, optional): Dimension of each head. Defaults to 64.
|
153 |
+
dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0.
|
154 |
+
attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default.
|
155 |
+
qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False.
|
156 |
+
out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False.
|
157 |
+
qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections.
|
158 |
+
Defaults to "SSI".
|
159 |
+
qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections.
|
160 |
+
Defaults to 'per_head'. Only support 'per_head'.
|
161 |
+
|
162 |
+
Examples:
|
163 |
+
>>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1)
|
164 |
+
>>> query = torch.randn(10, 128) # Batch size of 10
|
165 |
+
>>> context = torch.randn(10, 256) # Batch size of 10
|
166 |
+
>>> output = attn(query, context) # Perform the attention operation
|
167 |
+
|
168 |
+
Note:
|
169 |
+
https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
170 |
+
"""
|
171 |
+
|
172 |
+
def __init__(
|
173 |
+
self,
|
174 |
+
query_dim: int,
|
175 |
+
context_dim=None,
|
176 |
+
heads=8,
|
177 |
+
dim_head=64,
|
178 |
+
dropout=0.0,
|
179 |
+
attn_op: Optional[BaseAttentionOp] = None,
|
180 |
+
qkv_bias: bool = False,
|
181 |
+
out_bias: bool = False,
|
182 |
+
qkv_norm: str = "SSI",
|
183 |
+
qkv_norm_mode: str = "per_head",
|
184 |
+
backend: str = "transformer_engine",
|
185 |
+
qkv_format: str = "bshd",
|
186 |
+
) -> None:
|
187 |
+
super().__init__()
|
188 |
+
|
189 |
+
self.is_selfattn = context_dim is None # self attention
|
190 |
+
|
191 |
+
inner_dim = dim_head * heads
|
192 |
+
context_dim = query_dim if context_dim is None else context_dim
|
193 |
+
|
194 |
+
self.heads = heads
|
195 |
+
self.dim_head = dim_head
|
196 |
+
self.qkv_norm_mode = qkv_norm_mode
|
197 |
+
self.qkv_format = qkv_format
|
198 |
+
|
199 |
+
if self.qkv_norm_mode == "per_head":
|
200 |
+
norm_dim = dim_head
|
201 |
+
else:
|
202 |
+
raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
|
203 |
+
|
204 |
+
self.backend = backend
|
205 |
+
|
206 |
+
self.to_q = nn.Sequential(
|
207 |
+
nn.Linear(query_dim, inner_dim, bias=qkv_bias),
|
208 |
+
get_normalization(qkv_norm[0], norm_dim),
|
209 |
+
)
|
210 |
+
self.to_k = nn.Sequential(
|
211 |
+
nn.Linear(context_dim, inner_dim, bias=qkv_bias),
|
212 |
+
get_normalization(qkv_norm[1], norm_dim),
|
213 |
+
)
|
214 |
+
self.to_v = nn.Sequential(
|
215 |
+
nn.Linear(context_dim, inner_dim, bias=qkv_bias),
|
216 |
+
get_normalization(qkv_norm[2], norm_dim),
|
217 |
+
)
|
218 |
+
|
219 |
+
self.to_out = nn.Sequential(
|
220 |
+
nn.Linear(inner_dim, query_dim, bias=out_bias),
|
221 |
+
nn.Dropout(dropout),
|
222 |
+
)
|
223 |
+
|
224 |
+
if attn_op: # use what is given
|
225 |
+
self.attn_op = attn_op
|
226 |
+
elif self.backend == "transformer_engine":
|
227 |
+
sequence_parallel = False
|
228 |
+
self.attn_op: BaseAttentionOp = DotProductAttention(
|
229 |
+
self.heads,
|
230 |
+
self.dim_head,
|
231 |
+
num_gqa_groups=self.heads,
|
232 |
+
attention_dropout=0,
|
233 |
+
qkv_format=qkv_format,
|
234 |
+
attn_mask_type="no_mask",
|
235 |
+
tp_size=1,
|
236 |
+
tp_group=None,
|
237 |
+
sequence_parallel=sequence_parallel,
|
238 |
+
)
|
239 |
+
else:
|
240 |
+
raise ValueError(f"Backend {backend} not found")
|
241 |
+
|
242 |
+
def cal_qkv(
|
243 |
+
self, x, context=None, mask=None, rope_emb=None, **kwargs
|
244 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
245 |
+
del kwargs
|
246 |
+
|
247 |
+
"""
|
248 |
+
self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers.
|
249 |
+
Before 07/24/2024, these modules normalize across all heads.
|
250 |
+
After 07/24/2024, to support tensor parallelism and follow the common practice in the community,
|
251 |
+
we support normalizing per head.
|
252 |
+
To keep checkpoint compatibility with the previous code,
|
253 |
+
we keep the nn.Sequential but call the projection and the normalization layers separately.
|
254 |
+
We use a flag `self.qkv_norm_mode` to control the normalization behavior.
|
255 |
+
The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head.
|
256 |
+
"""
|
257 |
+
if self.qkv_norm_mode == "per_head":
|
258 |
+
q = self.to_q[0](x)
|
259 |
+
context = x if context is None else context
|
260 |
+
k = self.to_k[0](context)
|
261 |
+
v = self.to_v[0](context)
|
262 |
+
q, k, v = map(
|
263 |
+
lambda t: rearrange(t, "b ... (n c) -> b ... n c", n=self.heads, c=self.dim_head),
|
264 |
+
(q, k, v),
|
265 |
+
)
|
266 |
+
else:
|
267 |
+
raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
|
268 |
+
|
269 |
+
q = self.to_q[1](q)
|
270 |
+
k = self.to_k[1](k)
|
271 |
+
v = self.to_v[1](v)
|
272 |
+
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
|
273 |
+
q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True)
|
274 |
+
k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True)
|
275 |
+
return q, k, v
|
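A short sketch of the per-head split performed in cal_qkv above: the fused (heads * head_dim) channel axis is reshaped into separate head and per-head channel axes before the per-head normalization is applied.

import torch
from einops import rearrange

heads, head_dim = 4, 8
q = torch.randn(2, 10, heads * head_dim)                           # (b, seq, n*c)
q = rearrange(q, "b ... (n c) -> b ... n c", n=heads, c=head_dim)  # (b, seq, n, c)
assert q.shape == (2, 10, heads, head_dim)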
276 |
+
|
277 |
+
def cal_attn(self, q, k, v, mask=None):
|
278 |
+
if self.backend == "transformer_engine":
|
279 |
+
seq_dim = self.qkv_format.index("s")
|
280 |
+
assert (
|
281 |
+
q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1
|
282 |
+
), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version."
|
283 |
+
out = self.attn_op(q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None) # [B, Mq, H, V]
|
284 |
+
return self.to_out(out)
|
285 |
+
elif self.backend == "torch":
|
286 |
+
out = self.attn_op(q, k, v, mask=mask) # [B, Mq, H, V]
|
287 |
+
return self.to_out(rearrange(out, " b ... n c -> b ... (n c)"))
|
288 |
+
else:
|
289 |
+
raise ValueError(f"Backend {self.backend} not found")
|
290 |
+
|
291 |
+
def forward(
|
292 |
+
self,
|
293 |
+
x,
|
294 |
+
context=None,
|
295 |
+
mask=None,
|
296 |
+
rope_emb=None,
|
297 |
+
**kwargs,
|
298 |
+
):
|
299 |
+
"""
|
300 |
+
Args:
|
301 |
+
x (Tensor): The query tensor of shape [B, Mq, K]
|
302 |
+
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
303 |
+
"""
|
304 |
+
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
|
305 |
+
return self.cal_attn(q, k, v, mask)
|
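A small numeric sketch of what the normalize() helper above computes: after normalization, the mean squared element magnitude along the chosen dimensions is approximately 1. The shapes and scale factor are illustrative.

import numpy as np
import torch

x = torch.randn(3, 256) * 7.0
dim = [1]
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
norm = torch.add(1e-4, norm, alpha=np.sqrt(norm.numel() / x.numel()))
y = x / norm.to(x.dtype)
print((y ** 2).mean(dim=1))   # each entry is close to 1.0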
base_world_generation_pipeline.py
ADDED
@@ -0,0 +1,362 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import gc
|
17 |
+
import os
|
18 |
+
from abc import ABC
|
19 |
+
from typing import Any
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
|
24 |
+
from .t5_text_encoder import CosmosT5TextEncoder
|
25 |
+
from .presets import presets as guardrail_presets
|
26 |
+
|
27 |
+
|
28 |
+
class BaseWorldGenerationPipeline(ABC):
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
inference_type: str | None = None,
|
32 |
+
checkpoint_dir: str | None = None,
|
33 |
+
checkpoint_name: str | None = None,
|
34 |
+
enable_text_guardrail: bool = False,
|
35 |
+
enable_video_guardrail: bool = False,
|
36 |
+
offload_network: bool = False,
|
37 |
+
offload_tokenizer: bool = False,
|
38 |
+
offload_text_encoder_model: bool = False,
|
39 |
+
offload_guardrail_models: bool = False,
|
40 |
+
):
|
41 |
+
"""Initialize base world generation pipeline.
|
42 |
+
|
43 |
+
This abstract base class provides core functionality for world generation models including:
|
44 |
+
- Model loading and initialization
|
45 |
+
- Text encoding and embedding
|
46 |
+
- Safety checks and content filtering
|
47 |
+
- Memory management through model offloading
|
48 |
+
|
49 |
+
Args:
|
50 |
+
inference_type: The type of inference pipeline ("text2world" or "video2world")
|
51 |
+
checkpoint_dir: Root directory containing model checkpoints
|
52 |
+
checkpoint_name: Name of the specific checkpoint file to load
|
53 |
+
enable_text_guardrail: If True, validates input prompts for safety
|
54 |
+
enable_video_guardrail: If True, validates generated videos for safety
|
55 |
+
offload_network: If True, moves main model to CPU after inference
|
56 |
+
offload_tokenizer: If True, moves tokenizer to CPU after use
|
57 |
+
offload_text_encoder_model: If True, moves T5 encoder to CPU after encoding
|
58 |
+
offload_guardrail_models: If True, moves safety models to CPU after checks
|
59 |
+
"""
|
60 |
+
self.inference_type = inference_type
|
61 |
+
self.checkpoint_dir = checkpoint_dir
|
62 |
+
self.checkpoint_name = checkpoint_name
|
63 |
+
self.guardrail_dir = "Cosmos-1.0-Guardrail"
|
64 |
+
self.enable_text_guardrail = enable_text_guardrail
|
65 |
+
self.enable_video_guardrail = enable_video_guardrail
|
66 |
+
|
67 |
+
# Add offloading flags
|
68 |
+
self.offload_network = offload_network
|
69 |
+
self.offload_tokenizer = offload_tokenizer
|
70 |
+
self.offload_text_encoder_model = offload_text_encoder_model
|
71 |
+
self.offload_guardrail_models = offload_guardrail_models
|
72 |
+
|
73 |
+
# Initialize model instances
|
74 |
+
self.text_guardrail = None
|
75 |
+
self.video_guardrail = None
|
76 |
+
self.text_encoder = None
|
77 |
+
self.model = None
|
78 |
+
|
79 |
+
self._load_model()
|
80 |
+
|
81 |
+
if not self.offload_text_encoder_model:
|
82 |
+
self._load_text_encoder_model()
|
83 |
+
if not self.offload_guardrail_models:
|
84 |
+
if self.enable_text_guardrail:
|
85 |
+
self._load_text_guardrail()
|
86 |
+
if self.enable_video_guardrail:
|
87 |
+
self._load_video_guardrail()
|
88 |
+
if not self.offload_network:
|
89 |
+
self._load_network()
|
90 |
+
if not self.offload_tokenizer:
|
91 |
+
self._load_tokenizer()
|
92 |
+
|
93 |
+
def _load_tokenizer(self):
|
94 |
+
pass
|
95 |
+
|
96 |
+
def _load_network(self):
|
97 |
+
pass
|
98 |
+
|
99 |
+
def _load_model(self, checkpoint_name: str) -> Any:
|
100 |
+
"""Load the world generation model from a checkpoint.
|
101 |
+
|
102 |
+
This abstract method must be implemented by subclasses to load their specific
|
103 |
+
model architecture and weights.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
checkpoint_name: Path to the model checkpoint file
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
The loaded model instance
|
110 |
+
|
111 |
+
Raises:
|
112 |
+
NotImplementedError: Must be implemented by subclasses
|
113 |
+
"""
|
114 |
+
pass
|
115 |
+
|
116 |
+
def _load_text_encoder_model(self):
|
117 |
+
"""Load the T5 text encoder model.
|
118 |
+
|
119 |
+
Initializes and loads the T5 encoder model used for converting text prompts
|
120 |
+
into embeddings that condition the world generation model.
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
Loaded T5 text encoder model instance
|
124 |
+
"""
|
125 |
+
self.text_encoder = CosmosT5TextEncoder(cache_dir=self.checkpoint_dir)
|
126 |
+
|
127 |
+
def _load_text_guardrail(self):
|
128 |
+
"""Load text safety classifier models.
|
129 |
+
|
130 |
+
Initializes models used for checking input prompts against safety policies.
|
131 |
+
Models are loaded from the specified guardrail directory.
|
132 |
+
"""
|
133 |
+
self.text_guardrail = guardrail_presets.create_text_guardrail_runner(
|
134 |
+
checkpoint_dir=os.path.join(self.checkpoint_dir, self.guardrail_dir)
|
135 |
+
)
|
136 |
+
|
137 |
+
def _load_video_guardrail(self):
|
138 |
+
"""Load video safety classifier models.
|
139 |
+
|
140 |
+
Initializes models used for validating generated video content against
|
141 |
+
safety policies. Models are loaded from the specified guardrail directory.
|
142 |
+
"""
|
143 |
+
self.video_guardrail = guardrail_presets.create_video_guardrail_runner(
|
144 |
+
checkpoint_dir=os.path.join(self.checkpoint_dir, self.guardrail_dir)
|
145 |
+
)
|
146 |
+
|
147 |
+
def _offload_network(self):
|
148 |
+
if self.model.model:
|
149 |
+
del self.model.model
|
150 |
+
self.model.model = None
|
151 |
+
gc.collect()
|
152 |
+
torch.cuda.empty_cache()
|
153 |
+
|
154 |
+
def _offload_tokenizer(self):
|
155 |
+
if self.model.tokenizer:
|
156 |
+
del self.model.tokenizer
|
157 |
+
self.model.tokenizer = None
|
158 |
+
gc.collect()
|
159 |
+
torch.cuda.empty_cache()
|
160 |
+
|
161 |
+
def _offload_guardrail_models(self):
|
162 |
+
"""Offload safety classifier models to reduce memory usage.
|
163 |
+
|
164 |
+
Moves safety models to CPU and clears GPU memory if they are no longer needed.
|
165 |
+
This helps manage memory when processing multiple inputs sequentially.
|
166 |
+
"""
|
167 |
+
if self.text_guardrail:
|
168 |
+
del self.text_guardrail
|
169 |
+
self.text_guardrail = None
|
170 |
+
if self.video_guardrail:
|
171 |
+
del self.video_guardrail
|
172 |
+
self.video_guardrail = None
|
173 |
+
gc.collect()
|
174 |
+
torch.cuda.empty_cache()
|
175 |
+
|
176 |
+
def _offload_text_encoder_model(self):
|
177 |
+
"""Offload T5 text encoder to reduce memory usage.
|
178 |
+
|
179 |
+
        Moves the T5 encoder to CPU and clears GPU memory after text encoding is complete.

        This helps manage memory when processing multiple inputs sequentially.
        """
        if self.text_encoder:
            del self.text_encoder
            self.text_encoder = None
        gc.collect()
        torch.cuda.empty_cache()

    def _run_model(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Generate world latents using the model.

        This abstract method must be implemented by subclasses to define their specific
        generation process.

        Args:
            *args: Variable positional arguments for model inference
            **kwargs: Variable keyword arguments for model inference

        Returns:
            torch.Tensor: Generated world representation tensor
        """
        pass

    def _run_model_with_offload(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Generate world representation with memory management.

        Handles loading the model before inference and offloading afterward if enabled.
        This helps minimize GPU memory usage during inference.

        Args:
            *args: Arguments passed to _run_model
            **kwargs: Keyword arguments passed to _run_model

        Returns:
            np.ndarray: Generated world representation as a numpy array
        """
        pass

    def _run_guardrail_on_prompt(self, prompt: str) -> bool:
        """Check if the prompt meets safety requirements.

        Validates the input prompt against safety policies using the loaded guardrail models.

        Args:
            prompt: Raw text prompt to validate

        Returns:
            bool: True if the prompt passes all safety checks, False otherwise
        """
        return guardrail_presets.run_text_guardrail(prompt, self.text_guardrail)

    def _run_guardrail_on_prompt_with_offload(self, prompt: str) -> bool:
        """Check prompt safety with memory management.

        Validates prompt safety while handling model loading/offloading to manage memory.

        Args:
            prompt: Raw text prompt to validate

        Returns:
            bool: True if the prompt passes all safety checks, False otherwise
        """
        if self.offload_guardrail_models:
            self._load_text_guardrail()

        is_safe = self._run_guardrail_on_prompt(prompt)

        if self.offload_guardrail_models:
            self._offload_guardrail_models()

        return is_safe

    def _run_guardrail_on_video(self, video: np.ndarray) -> np.ndarray | None:
        """Check if the video meets safety requirements.

        Validates generated video content against safety policies using the guardrail models.

        Args:
            video: Video frames to validate

        Returns:
            np.ndarray: Processed video if safe, None if unsafe
        """
        return guardrail_presets.run_video_guardrail(video, self.video_guardrail)

    def _run_guardrail_on_video_with_offload(self, video: np.ndarray) -> np.ndarray | None:
        """Check if the generated video meets safety requirements.

        Args:
            video: Video frames to validate

        Returns:
            np.ndarray: Processed video frames if safe, None otherwise

        Note:
            Guardrail models are offloaded after the checks if enabled.
        """
        if self.offload_guardrail_models:
            self._load_video_guardrail()

        video = self._run_guardrail_on_video(video)

        if self.offload_guardrail_models:
            self._offload_guardrail_models()
        return video

    def _run_text_embedding_on_prompt(
        self, prompts: list[str], **kwargs: Any
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Convert text prompts to embeddings.

        Processes text prompts into embedding tensors that condition the generation model.

        Args:
            prompts: List of text prompts to encode
            **kwargs: Additional arguments for text encoding

        Returns:
            tuple containing:
                - List of text embedding tensors for each prompt
                - List of attention masks for each embedding
        """
        embeddings = []
        masks = []
        for prompt in prompts:
            embedding, mask = self.text_encoder.encode_prompts(
                [prompt],
                **kwargs,
            )
            embeddings.append(embedding)
            masks.append(mask)

        return embeddings, masks

    def _run_text_embedding_on_prompt_with_offload(
        self, prompts: list[str], **kwargs: Any
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Convert text prompts into embeddings using the T5 encoder.

        Args:
            prompts: Processed and validated text prompts

        Returns:
            Text embedding tensors and attention masks used to condition the diffusion model

        Note:
            The T5 model is offloaded after encoding if enabled.
        """
        if self.offload_text_encoder_model:
            self._load_text_encoder_model()

        embeddings, masks = self._run_text_embedding_on_prompt(prompts, **kwargs)

        if self.offload_text_encoder_model:
            self._offload_text_encoder_model()
        return embeddings, masks

    def _run_tokenizer_decoding(self, samples: torch.Tensor) -> np.ndarray:
        """Decode model outputs into the final world representation.

        This abstract method must be implemented by subclasses to convert raw model
        outputs into their specific world representation format.

        Args:
            samples: Raw output tensor from the generation model

        Returns:
            np.ndarray: Decoded world representation
        """
        pass

    def generate(self, *args: Any, **kwargs: Any):
        """Generate a world representation.

        This abstract method must be implemented by subclasses to run the full
        generation pipeline end to end.

        Args:
            *args: Variable positional arguments for model inference
            **kwargs: Variable keyword arguments for model inference
        """
        pass
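The helpers above are meant to be chained by a concrete pipeline: guardrail the prompt, embed it, generate, then guardrail the output. A minimal sketch of that flow, where `pipeline` is a hypothetical subclass instance and the `_run_model_with_offload` signature is an assumption; only the guardrail and embedding helpers are taken from this file:

```python
# Illustrative only: `pipeline` and the _run_model_with_offload arguments are assumptions.
prompt = "A robot arm stacking boxes in a warehouse"

if not pipeline._run_guardrail_on_prompt_with_offload(prompt):
    raise ValueError("Prompt failed the text guardrail")

embeddings, masks = pipeline._run_text_embedding_on_prompt_with_offload([prompt])
video = pipeline._run_model_with_offload(embeddings[0], masks[0])  # subclass-defined
video = pipeline._run_guardrail_on_video_with_offload(video)
if video is None:
    raise ValueError("Generated video failed the video guardrail")
```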
batch_ops.py
ADDED
@@ -0,0 +1,46 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Functions for performing operations with broadcasting to the right axis
#
# Example
#   input1: tensor of size (N1, N2)
#   input2: tensor of size (N1, N2, N3, N4)
#   batch_mul(input1, input2) = input1[:, :, None, None] * input2
#
# If the common dimensions don't match, we raise an assertion error.

from torch import Tensor


def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
    ndims1 = x.ndim
    ndims2 = y.ndim

    common_ndims = min(ndims1, ndims2)
    for axis in range(common_ndims):
        assert x.shape[axis] == y.shape[axis], "Dimensions not equal at axis {}".format(axis)

    if ndims1 < ndims2:
        x = x.reshape(x.shape + (1,) * (ndims2 - ndims1))
    elif ndims2 < ndims1:
        y = y.reshape(y.shape + (1,) * (ndims1 - ndims2))

    return x, y


def batch_mul(x: Tensor, y: Tensor) -> Tensor:
    x, y = common_broadcast(x, y)
    return x * y
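As a quick illustration of the right-axis broadcasting described in the header comment, the sketch below (assuming the module is importable as `batch_ops`) pads the lower-rank tensor with trailing singleton dimensions before multiplying:

```python
import torch

from batch_ops import batch_mul  # import path is an assumption

scales = torch.rand(4, 2)          # (N1, N2), e.g. per-sample, per-frame scales
latents = torch.randn(4, 2, 8, 8)  # (N1, N2, N3, N4)

out = batch_mul(scales, latents)   # same as scales[:, :, None, None] * latents
assert out.shape == latents.shape
```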
blocklist.py
ADDED
@@ -0,0 +1,219 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import re
import string
from difflib import SequenceMatcher

import nltk
from better_profanity import profanity

from .guardrail_blocklist_utils import read_keyword_list_from_dir, to_ascii
from .guardrail_core import ContentSafetyGuardrail, GuardrailRunner
from .log import log
from .misc import misc, Color, timer

DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/blocklist"
CENSOR = Color.red("*")


class Blocklist(ContentSafetyGuardrail):
    def __init__(
        self,
        checkpoint_dir: str = DEFAULT_CHECKPOINT_DIR,
        guardrail_partial_match_min_chars: int = 4,
        guardrail_partial_match_letter_count: float = 0.5,
    ) -> None:
        nltk.data.path.append(os.path.join(checkpoint_dir, "nltk_data"))
        self.lemmatizer = nltk.WordNetLemmatizer()
        self.profanity = profanity
        self.checkpoint_dir = checkpoint_dir
        self.guardrail_partial_match_min_chars = guardrail_partial_match_min_chars
        self.guardrail_partial_match_letter_count = guardrail_partial_match_letter_count

        # Load blocklist and whitelist keywords
        self.blocklist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "custom"))
        self.whitelist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "whitelist"))
        self.exact_match_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "exact_match"))

        self.profanity.load_censor_words(custom_words=self.blocklist_words, whitelist_words=self.whitelist_words)
        log.debug(f"Loaded {len(self.blocklist_words)} words/phrases from blocklist")
        log.debug(f"Whitelisted {len(self.whitelist_words)} words/phrases from whitelist")
        log.debug(f"Loaded {len(self.exact_match_words)} exact match words/phrases from blocklist")

    def uncensor_whitelist(self, input_prompt: str, censored_prompt: str) -> str:
        """Explicitly uncensor words that are in the whitelist."""
        input_words = input_prompt.split()
        censored_words = censored_prompt.split()
        whitelist_words = set(self.whitelist_words)
        for i, token in enumerate(input_words):
            if token.strip(string.punctuation).lower() in whitelist_words:
                censored_words[i] = token
        censored_prompt = " ".join(censored_words)
        return censored_prompt

    def censor_prompt(self, input_prompt: str) -> tuple[bool, str]:
        """Censor the prompt using the blocklist with better-profanity fuzzy matching.

        Args:
            input_prompt: input prompt to censor

        Returns:
            bool: True if the prompt is blocked, False otherwise
            str: A message indicating why the prompt was blocked
        """
        censored_prompt = self.profanity.censor(input_prompt, censor_char=CENSOR)
        # Uncensor whitelisted words that were censored by blocklist fuzzy matching
        censored_prompt = self.uncensor_whitelist(input_prompt, censored_prompt)
        if CENSOR in censored_prompt:
            return True, f"Prompt blocked by censorship: Censored Prompt: {censored_prompt}"
        return False, ""

    @staticmethod
    def check_partial_match(
        normalized_prompt: str, normalized_word: str, guardrail_partial_match_letter_count: float
    ) -> tuple[bool, str]:
        """
        Check robustly if the normalized word and the matching target differ by at most
        guardrail_partial_match_letter_count characters.

        Args:
            normalized_prompt: a string with many words
            normalized_word: a string with one or more words; shorter than normalized_prompt
            guardrail_partial_match_letter_count: maximum allowed difference in characters (float to allow partial characters)

        Returns:
            bool: True if a match is found, False otherwise
            str: A message indicating why the prompt was blocked
        """
        prompt_words = normalized_prompt.split()
        word_length = len(normalized_word.split())
        max_similarity_ratio = (len(normalized_word) - float(guardrail_partial_match_letter_count)) / float(
            len(normalized_word)
        )

        for i in range(len(prompt_words) - word_length + 1):
            # Extract a substring from the prompt with the same number of words as the normalized_word
            substring = " ".join(prompt_words[i : i + word_length])
            similarity_ratio = SequenceMatcher(None, substring, normalized_word).ratio()
            if similarity_ratio >= max_similarity_ratio:
                return (
                    True,
                    f"Prompt blocked by partial match blocklist: Prompt: {normalized_prompt}, Partial Match Word: {normalized_word}",
                )

        return False, ""

    @staticmethod
    def check_against_whole_word_blocklist(
        prompt: str,
        blocklist: list[str],
        guardrail_partial_match_min_chars: int = 4,
        guardrail_partial_match_letter_count: float = 0.5,
    ) -> tuple[bool, str]:
        """
        Check if the prompt contains any whole words from the blocklist.
        The match is case insensitive and robust to multiple spaces between words.

        Args:
            prompt: input prompt to check
            blocklist: list of words to check against
            guardrail_partial_match_min_chars: minimum number of characters in a word to check for partial match
            guardrail_partial_match_letter_count: maximum allowed difference in characters for partial match

        Returns:
            bool: True if a match is found, False otherwise
            str: A message indicating why the prompt was blocked
        """
        # Normalize spaces and convert to lowercase
        normalized_prompt = re.sub(r"\s+", " ", prompt).strip().lower()

        for word in blocklist:
            # Normalize spaces and convert to lowercase for each blocklist word
            normalized_word = re.sub(r"\s+", " ", word).strip().lower()

            # Use word boundaries to ensure whole word match
            if re.search(r"\b" + re.escape(normalized_word) + r"\b", normalized_prompt):
                return True, f"Prompt blocked by exact match blocklist: Prompt: {prompt}, Exact Match Word: {word}"

            # Check for partial match if the word is long enough
            if len(normalized_word) >= guardrail_partial_match_min_chars:
                match, message = Blocklist.check_partial_match(
                    normalized_prompt, normalized_word, guardrail_partial_match_letter_count
                )
                if match:
                    return True, message

        return False, ""

    def is_safe(self, input_prompt: str = "") -> tuple[bool, str]:
        """Check if the input prompt is safe using the blocklist."""
        # Check if the input is empty
        if not input_prompt:
            return False, "Input is empty"
        input_prompt = to_ascii(input_prompt)

        # Check full sentence for censored words
        censored, message = self.censor_prompt(input_prompt)
        if censored:
            return False, message

        # Check lemmatized words for censored words
        tokens = nltk.word_tokenize(input_prompt)
        lemmas = [self.lemmatizer.lemmatize(token) for token in tokens]
        lemmatized_prompt = " ".join(lemmas)
        censored, message = self.censor_prompt(lemmatized_prompt)
        if censored:
            return False, message

        # Check for exact match blocklist words
        censored, message = self.check_against_whole_word_blocklist(
            input_prompt,
            self.exact_match_words,
            self.guardrail_partial_match_min_chars,
            self.guardrail_partial_match_letter_count,
        )
        if censored:
            return False, message

        # If all these checks pass, the input is safe
        return True, "Input is safe"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        help="Path to the Blocklist checkpoint folder",
        default=DEFAULT_CHECKPOINT_DIR,
    )
    return parser.parse_args()


def main(args):
    blocklist = Blocklist(checkpoint_dir=args.checkpoint_dir)
    runner = GuardrailRunner(safety_models=[blocklist])
    with timer("blocklist safety check"):
        safety, message = runner.run_safety_check(args.prompt)
    log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}")
    log.info(f"Message: {message}") if not safety else None


if __name__ == "__main__":
    args = parse_args()
    main(args)
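To make the partial-match threshold concrete: with the default `guardrail_partial_match_letter_count` of 0.5, a candidate word only trips the matcher if its `SequenceMatcher` ratio is within about half a character of a blocklist entry, so simple plurals match while unrelated words do not. A small sketch (the import path and example words are assumptions):

```python
from blocklist import Blocklist  # import path is an assumption

# "grenades" vs the blocklist word "grenade": ratio 14/15 ≈ 0.933,
# above the (7 - 0.5) / 7 ≈ 0.929 threshold, so the prompt is blocked.
blocked, message = Blocklist.check_partial_match("buy two grenades now", "grenade", 0.5)
print(blocked)  # True

# An unrelated word falls far below the threshold and passes.
blocked, _ = Blocklist.check_partial_match("buy two oranges now", "grenade", 0.5)
print(blocked)  # False
```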
blocks.py
ADDED
@@ -0,0 +1,545 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional

import numpy as np
import torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn

from .attention import Attention, GPT2FeedForward
from .log import log


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class Timesteps(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels

    def forward(self, timesteps):
        in_dtype = timesteps.dtype
        half_dim = self.num_channels // 2
        exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        exponent = exponent / (half_dim - 0.0)

        emb = torch.exp(exponent)
        emb = timesteps[:, None].float() * emb[None, :]

        sin_emb = torch.sin(emb)
        cos_emb = torch.cos(emb)
        emb = torch.cat([cos_emb, sin_emb], dim=-1)

        return emb.to(in_dtype)


class TimestepEmbedding(nn.Module):
    def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False):
        super().__init__()
        log.debug(
            f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
        )
        self.linear_1 = nn.Linear(in_features, out_features, bias=not use_adaln_lora)
        self.activation = nn.SiLU()
        self.use_adaln_lora = use_adaln_lora
        if use_adaln_lora:
            self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False)
        else:
            self.linear_2 = nn.Linear(out_features, out_features, bias=True)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        emb = self.linear_1(sample)
        emb = self.activation(emb)
        emb = self.linear_2(emb)

        if self.use_adaln_lora:
            adaln_lora_B_3D = emb
            emb_B_D = sample
        else:
            emb_B_D = emb
            adaln_lora_B_3D = None

        return emb_B_D, adaln_lora_B_3D


class FourierFeatures(nn.Module):
    """
    Implements a layer that generates Fourier features from input tensors, based on randomly sampled
    frequencies and phases. This can help in learning high-frequency functions in low-dimensional problems.

    [B] -> [B, D]

    Parameters:
        num_channels (int): The number of Fourier features to generate.
        bandwidth (float, optional): The scaling factor for the frequency of the Fourier features. Defaults to 1.
        normalize (bool, optional): If set to True, the outputs are scaled by sqrt(2), usually to normalize
            the variance of the features. Defaults to False.

    Example:
        >>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True)
        >>> x = torch.randn(10, 256)  # Example input tensor
        >>> output = layer(x)
        >>> print(output.shape)  # Expected shape: (10, 256)
    """

    def __init__(self, num_channels, bandwidth=1, normalize=False):
        super().__init__()
        self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
        self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
        self.gain = np.sqrt(2) if normalize else 1

    def forward(self, x, gain: float = 1.0):
        """
        Apply the Fourier feature transformation to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.
            gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1.

        Returns:
            torch.Tensor: The transformed tensor, with Fourier features applied.
        """
        in_dtype = x.dtype
        x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
        x = x.cos().mul(self.gain * gain).to(in_dtype)
        return x


class PatchEmbed(nn.Module):
    """
    PatchEmbed is a module for embedding patches from an input tensor, handling both temporal (video)
    and spatial (image) dimensions, which makes it suitable for video and image processing tasks.
    It supports dividing the input into patches and embedding each patch into a vector of size `out_channels`.

    Parameters:
    - spatial_patch_size (int): The size of each spatial patch.
    - temporal_patch_size (int): The size of each temporal patch.
    - in_channels (int): Number of input channels. Default: 3.
    - out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
    - bias (bool): If True, adds a learnable bias to the output of the projection layers. Default: True.
    """

    def __init__(
        self,
        spatial_patch_size,
        temporal_patch_size,
        in_channels=3,
        out_channels=768,
        bias=True,
    ):
        super().__init__()
        self.spatial_patch_size = spatial_patch_size
        self.temporal_patch_size = temporal_patch_size

        self.proj = nn.Sequential(
            Rearrange(
                "b c (t r) (h m) (w n) -> b t h w (c r m n)",
                r=temporal_patch_size,
                m=spatial_patch_size,
                n=spatial_patch_size,
            ),
            nn.Linear(
                in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias
            ),
        )
        self.out = nn.Identity()

    def forward(self, x):
        """
        Forward pass of the PatchEmbed module.

        Parameters:
        - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
            B is the batch size,
            C is the number of channels,
            T is the temporal dimension,
            H is the height, and
            W is the width of the input.

        Returns:
        - torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
        """
        assert x.dim() == 5
        _, _, T, H, W = x.shape
        assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
        assert T % self.temporal_patch_size == 0
        x = self.proj(x)
        return self.out(x)


class FinalLayer(nn.Module):
    """
    The final layer of the video DiT.
    """

    def __init__(
        self,
        hidden_size,
        spatial_patch_size,
        temporal_patch_size,
        out_channels,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
    ):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False
        )
        self.hidden_size = hidden_size
        self.n_adaln_chunks = 2
        self.use_adaln_lora = use_adaln_lora
        if use_adaln_lora:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, adaln_lora_dim, bias=False),
                nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False),
            )
        else:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False)
            )

    def forward(
        self,
        x_BT_HW_D,
        emb_B_D,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ):
        if self.use_adaln_lora:
            assert adaln_lora_B_3D is not None
            shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk(
                2, dim=1
            )
        else:
            shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1)

        B = emb_B_D.shape[0]
        T = x_BT_HW_D.shape[0] // B
        shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T)
        x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D)

        x_BT_HW_D = self.linear(x_BT_HW_D)
        return x_BT_HW_D


class VideoAttn(nn.Module):
    """
    Implements video attention with optional cross-attention capabilities.

    This module processes video features while maintaining their spatio-temporal structure. It can perform
    self-attention within the video features or cross-attention with external context features.

    Parameters:
        x_dim (int): Dimension of input feature vectors
        context_dim (Optional[int]): Dimension of context features for cross-attention. None for self-attention
        num_heads (int): Number of attention heads
        bias (bool): Whether to include bias in attention projections. Default: False
        qkv_norm_mode (str): Normalization mode for query/key/value projections. Must be "per_head". Default: "per_head"
        x_format (str): Format of input tensor. Must be "BTHWD". Default: "BTHWD"

    Input shape:
        - x: (T, H, W, B, D) video features
        - context (optional): (M, B, D) context features for cross-attention
        where:
            T: temporal dimension
            H: height
            W: width
            B: batch size
            D: feature dimension
            M: context sequence length
    """

    def __init__(
        self,
        x_dim: int,
        context_dim: Optional[int],
        num_heads: int,
        bias: bool = False,
        qkv_norm_mode: str = "per_head",
        x_format: str = "BTHWD",
    ) -> None:
        super().__init__()
        self.x_format = x_format

        self.attn = Attention(
            x_dim,
            context_dim,
            num_heads,
            x_dim // num_heads,
            qkv_bias=bias,
            qkv_norm="RRI",
            out_bias=bias,
            qkv_norm_mode=qkv_norm_mode,
            qkv_format="sbhd",
        )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        crossattn_mask: Optional[torch.Tensor] = None,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for video attention.

        Args:
            x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data.
            context (Tensor): Context tensor of shape (B, M, D) or (M, B, D),
                where M is the sequence length of the context.
            crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms.
            rope_emb_L_1_1_D (Optional[Tensor]):
                Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.

        Returns:
            Tensor: The output tensor with applied attention, maintaining the input shape.
        """
        x_T_H_W_B_D = x
        context_M_B_D = context
        T, H, W, B, D = x_T_H_W_B_D.shape
        x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d")
        x_THW_B_D = self.attn(
            x_THW_B_D,
            context_M_B_D,
            crossattn_mask,
            rope_emb=rope_emb_L_1_1_D,
        )
        x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
        return x_T_H_W_B_D


def adaln_norm_state(norm_state, x, scale, shift):
    normalized = norm_state(x)
    return normalized * (1 + scale) + shift


class DITBuildingBlock(nn.Module):
    """
    A building block for the DiT (Diffusion Transformer) architecture that supports different types of
    attention and MLP operations with adaptive layer normalization.

    Parameters:
        block_type (str): Type of block - one of:
            - "cross_attn"/"ca": Cross-attention
            - "full_attn"/"fa": Full self-attention
            - "mlp"/"ff": MLP/feedforward block
        x_dim (int): Dimension of input features
        context_dim (Optional[int]): Dimension of context features for cross-attention
        num_heads (int): Number of attention heads
        mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
        bias (bool): Whether to use bias in layers. Default: False
        mlp_dropout (float): Dropout rate for MLP. Default: 0.0
        qkv_norm_mode (str): QKV normalization mode. Default: "per_head"
        x_format (str): Input tensor format. Default: "BTHWD"
        use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
        adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
    """

    def __init__(
        self,
        block_type: str,
        x_dim: int,
        context_dim: Optional[int],
        num_heads: int,
        mlp_ratio: float = 4.0,
        bias: bool = False,
        mlp_dropout: float = 0.0,
        qkv_norm_mode: str = "per_head",
        x_format: str = "BTHWD",
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
    ) -> None:
        block_type = block_type.lower()

        super().__init__()
        self.x_format = x_format
        if block_type in ["cross_attn", "ca"]:
            self.block = VideoAttn(
                x_dim,
                context_dim,
                num_heads,
                bias=bias,
                qkv_norm_mode=qkv_norm_mode,
                x_format=self.x_format,
            )
        elif block_type in ["full_attn", "fa"]:
            self.block = VideoAttn(
                x_dim, None, num_heads, bias=bias, qkv_norm_mode=qkv_norm_mode, x_format=self.x_format
            )
        elif block_type in ["mlp", "ff"]:
            self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias)
        else:
            raise ValueError(f"Unknown block type: {block_type}")

        self.block_type = block_type
        self.use_adaln_lora = use_adaln_lora

        self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
        self.n_adaln_chunks = 3
        if use_adaln_lora:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(x_dim, adaln_lora_dim, bias=False),
                nn.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False),
            )
        else:
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False))

    def forward(
        self,
        x: torch.Tensor,
        emb_B_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        crossattn_mask: Optional[torch.Tensor] = None,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for dynamically configured blocks with adaptive normalization.

        Args:
            x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D).
            emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation.
            crossattn_emb (Tensor): Tensor for cross-attention blocks.
            crossattn_mask (Optional[Tensor]): Optional mask for cross-attention.
            rope_emb_L_1_1_D (Optional[Tensor]):
                Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.

        Returns:
            Tensor: The output tensor after processing through the configured block and adaptive normalization.
        """
        if self.use_adaln_lora:
            shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk(
                self.n_adaln_chunks, dim=1
            )
        else:
            shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1)

        shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = (
            shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
            scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
            gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
        )

        if self.block_type in ["mlp", "ff"]:
            x = x + gate_1_1_1_B_D * self.block(
                adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
            )
        elif self.block_type in ["full_attn", "fa"]:
            x = x + gate_1_1_1_B_D * self.block(
                adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
                context=None,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
            )
        elif self.block_type in ["cross_attn", "ca"]:
            x = x + gate_1_1_1_B_D * self.block(
                adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
                context=crossattn_emb,
                crossattn_mask=crossattn_mask,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
            )
        else:
            raise ValueError(f"Unknown block type: {self.block_type}")

        return x


class GeneralDITTransformerBlock(nn.Module):
    """
    A wrapper module that manages a sequence of DITBuildingBlocks to form a complete transformer layer.
    Each block in the sequence is specified by a block configuration string.

    Parameters:
        x_dim (int): Dimension of input features
        context_dim (int): Dimension of context features for cross-attention blocks
        num_heads (int): Number of attention heads
        block_config (str): String specifying block sequence (e.g. "ca-fa-mlp" for cross-attention,
            full-attention, then MLP)
        mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
        x_format (str): Input tensor format. Default: "BTHWD"
        use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
        adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256

    The block_config string uses "-" to separate block types:
        - "ca"/"cross_attn": Cross-attention block
        - "fa"/"full_attn": Full self-attention block
        - "mlp"/"ff": MLP/feedforward block

    Example:
        block_config = "ca-fa-mlp" creates a sequence of:
        1. Cross-attention block
        2. Full self-attention block
        3. MLP block
    """

    def __init__(
        self,
        x_dim: int,
        context_dim: int,
        num_heads: int,
        block_config: str,
        mlp_ratio: float = 4.0,
        x_format: str = "BTHWD",
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
    ):
        super().__init__()
        self.blocks = nn.ModuleList()
        self.x_format = x_format
        for block_type in block_config.split("-"):
            self.blocks.append(
                DITBuildingBlock(
                    block_type,
                    x_dim,
                    context_dim,
                    num_heads,
                    mlp_ratio,
                    x_format=self.x_format,
                    use_adaln_lora=use_adaln_lora,
                    adaln_lora_dim=adaln_lora_dim,
                )
            )

    def forward(
        self,
        x: torch.Tensor,
        emb_B_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        crossattn_mask: Optional[torch.Tensor] = None,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if extra_per_block_pos_emb is not None:
            x = x + extra_per_block_pos_emb
        for block in self.blocks:
            x = block(
                x,
                emb_B_D,
                crossattn_emb,
                crossattn_mask,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                adaln_lora_B_3D=adaln_lora_B_3D,
            )
        return x
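A shape-level sketch of the timestep-conditioning path defined above (`Timesteps` followed by `TimestepEmbedding`), assuming the module can be imported as `blocks`. With `use_adaln_lora=True` the embedding returns the raw sinusoidal features plus a 3x-wide tensor that the DiT blocks later chunk into shift/scale/gate terms:

```python
import torch

from blocks import TimestepEmbedding, Timesteps  # import path is an assumption

t_embedder = Timesteps(num_channels=256)
adaln_embedder = TimestepEmbedding(in_features=256, out_features=1024, use_adaln_lora=True)

timesteps = torch.randint(0, 1000, (8,)).float()  # one diffusion timestep per sample
freq_emb = t_embedder(timesteps)                  # (8, 256) sinusoidal features
emb_B_D, adaln_lora_B_3D = adaln_embedder(freq_emb)
print(emb_B_D.shape, adaln_lora_B_3D.shape)       # torch.Size([8, 256]) torch.Size([8, 3072])
```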
blur_utils.py
ADDED
@@ -0,0 +1,35 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cv2
import numpy as np


def pixelate_face(face_img: np.ndarray, blocks: int = 5) -> np.ndarray:
    """
    Pixelate a face region by reducing resolution and then upscaling.

    Args:
        face_img: Face region to pixelate
        blocks: Number of blocks to divide the face into (in each dimension)

    Returns:
        Pixelated face region
    """
    h, w = face_img.shape[:2]
    # Shrink the image and scale back up to create pixelation effect
    temp = cv2.resize(face_img, (blocks, blocks), interpolation=cv2.INTER_LINEAR)
    pixelated = cv2.resize(temp, (w, h), interpolation=cv2.INTER_NEAREST)
    return pixelated
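A minimal usage sketch (the frame and bounding-box values below are made up): pixelate a detected face crop in place before the frame is written back out.

```python
import numpy as np

from blur_utils import pixelate_face  # import path is an assumption

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a decoded video frame
y0, y1, x0, x1 = 100, 220, 300, 400              # assumed face bounding box
frame[y0:y1, x0:x1] = pixelate_face(frame[y0:y1, x0:x1], blocks=5)
```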
categories.py
ADDED
@@ -0,0 +1,192 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

VIOLENCE = {
    "category": "violence",
    "blocked_reason": """This prompt is blocked due to the presence of violent content.""",
    "instruction": """O1: Violence.
Should not
- Help people plan or engage in violence.
- Encourage people to engage in violence.
- Contain unpleasant vivid violent imagery or violent description that arouses discomfort in the reader.
Can
- Provide information on violence.
- Discuss the topics of violence, and can discuss historical events involving violence.
""",
}

SEXUAL = {
    "category": "sexual",
    "blocked_reason": """This prompt is blocked due to the presence of sexual content.""",
    "instruction": """O2: Sexual.
Should not
- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
- Engage or encourage sexual activity that is considered illegal.
Can
- Discuss topics related to sexuality and sexual education.
- Provide general information regarding sexual activities.
- Provide guidance on topics related to sexual health.
""",
}

CRIMINAL_PLANNING = {
    "category": "criminal_planning",
    "blocked_reason": """This prompt is blocked due to the presence of criminal planning/confessions content.""",
    "instruction": """O3: Criminal Planning/Confessions.
Should not
- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
- Encourage people to engage in criminal activities or condone criminal behavior.
- Have a confession of crime that was committed or will be committed in future.
Can
- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.
""",
}

WEAPONS = {
    "category": "weapons",
    "blocked_reason": """This prompt is blocked due to the presence of weapons content.""",
    "instruction": """O4: Guns and Illegal Weapons.
Should not
- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
- Assist or encourage people to execute or own other illegal weapons for example biological agents, or chemical weapons, including instructions on how to create or use them.
Can
- Discuss firearms and the arguments for and against firearm ownership.
""",
}

SUBSTANCE_ABUSE = {
    "category": "substance_abuse",
    "blocked_reason": """This prompt is blocked due to the presence of substance abuse content.""",
    "instruction": """O5: Controlled/Regulated Substances.
Should not
- Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
- Assist or encourage people to create such substances when it is illegal to do so.
Can
- Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
- Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation.
""",
}

SUICIDE = {
    "category": "suicide",
    "blocked_reason": "This prompt is blocked due to the presence of suicide and self harm content.",
    "instruction": """O6: Suicide and Self Harm.
Should not
- Encourage people to harm themselves.
- Romanticize or condone self-harm or suicide.
- Provide information regarding the methods of suicide or self-harm.
- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.
- Contain confessions or intent of self harm
Can
- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).
Should
- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.
""",
}

CSAM = {
    "category": "csam",
    "blocked_reason": "This prompt contains material related to Child Sexual Abuse",
    "instruction": """O7: Sexual (minor).
Should not
- Assist, engage or encourage intention to engage in sexual activity with an underage/minor.
- Discuss or plan sexual acts, sharing of media of sexual nature, personal information or intention to meet in person for sexual activity with a minor.
- Should not include discussion where private information is shared with underage minors for purposes of acting on them and taking actions that are sexual in nature including meeting or sharing media.
- Should not discuss child abuse or engage, encourage, assist in sexual grooming/trafficking of children.
Can
- Provide information about sex education for use in schools.
""",
}

HATRED = {
    "category": "hatred",
    "blocked_reason": "This prompt is blocked due to the presence of hate speech content.",
    "instruction": """O8: Hate/Identity Hate.
Should not
- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their
race, color, ethnicity, religion, national origin, sexual orientation, gender, gender identity, or disability or caste.
- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
- Should not discuss prejudice towards a group classified as protected according to the law of Germany, UK, USA or any other country.
- Should not contain slurs.
- Should not contain text that is distasteful, offensive, or hateful that causes others to feel grief, anger, or fear.
Can
- Provide information on discrimination.
- Discuss the topics of discrimination, and can discuss historical events involving discrimination.
- Discuss neutral statements referring to group identity and contextualized hate speech (that talks about hate but does not incite hate).
""",
}

HARASSMENT = {
    "category": "harassment",
    "blocked_reason": "This prompt is blocked due to the presence of harassment content.",
    "instruction": """O10: Harassment.
Should not
- Have content that may be used to torment or annoy individuals in real life, or make harassment more likely to occur.
- Engage in harassment, misconduct and unlawful discrimination
- Single out an individual or marginalization and/or retaliation based on protected characteristics like Race, Color, Gender, Sex,
Sexual orientation, Gender identity, and gender expression, National origin, Ethnicity, Disability (including being regarded as disabled),
Religion, Age (40+), Pregnancy (including pregnancy, childbirth or related medical conditions), Genetic information, Military or veteran status, Citizenship status, Political activity, or affiliation,
- Taking or requesting statutorily protected leave Body characteristics.
- Promise to give a benefit, or a threat to retaliate or take an adverse action based on a response to harassment and misconduct.
- Bullying.
- Being derogatory and humiliating toward an individual but not necessarily protected characteristics under law.
- Making rude or insulting comments, demeaning, and objectifying terms toward an individual.
Can
- Discuss negative impacts of harassment and bullying based on protected characteristics like race, color, gender, sex, nationality, within a contextualized setting.
- Provide information on protected characteristics or refers, cites material that describes what protected characteristics are.
- Provide information or refer or cite material that informs what harassment is.
""",
}

THREAT = {
    "category": "threat",
    "blocked_reason": "This prompt is blocked due to the presence of threat content.",
    "instruction": """O11: Threat.
Should not
- Discuss a communication of intent to inflict harm or loss on another person.
- Should not contain discussion of intention of intimidation which is a tactic used between conflicting parties to make the other timid or psychologically insecure for coercion or control.
- Should not contain intention of knowingly putting another person in fear of bodily injury.
- Should not assist or plan injury, physical or mental damage, act or instance of injury, or material or any kind of loss to a person.
- Should not contain threats that are forbidden by law as intent to obtain a monetary advantage or to compel a person to act against their will.
Can
- Provide information about threat is and the meaning of threat.
""",
}

PROFANITY = {
    "category": "profanity",
    "blocked_reason": "This prompt is blocked due to the presence of profanity content.",
    "instruction": """O12: Profanity.
Should not
- Contain swear words, curse words, or other obscene or profane language
- This includes offensive words used without any intention to act on them or without any explicit targeted recipient
""",
}

UNSAFE_CATEGORIES = [
    VIOLENCE,
    SEXUAL,
    CRIMINAL_PLANNING,
    WEAPONS,
    SUBSTANCE_ABUSE,
    SUICIDE,
    CSAM,
    HATRED,
    HARASSMENT,
    THREAT,
    PROFANITY,
]
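Downstream safety checks can flatten this table into a single LlamaGuard-style policy string and map a flagged category back to its user-facing message. A sketch of that pattern (the import path is an assumption, and the exact prompt format used by the guardrail model may differ):

```python
from categories import UNSAFE_CATEGORIES  # import path is an assumption

policy_text = "\n".join(entry["instruction"] for entry in UNSAFE_CATEGORIES)
blocked_reason = {entry["category"]: entry["blocked_reason"] for entry in UNSAFE_CATEGORIES}

print(policy_text.splitlines()[0])  # "O1: Violence."
print(blocked_reason["profanity"])  # message shown when the profanity category fires
```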
checkpoint.py
ADDED
@@ -0,0 +1,76 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Dict, Optional
|
17 |
+
|
18 |
+
import torch
|
19 |
+
|
20 |
+
# Substrings to ignore when processing state dicts
|
21 |
+
substrings_to_ignore = [
|
22 |
+
"_extra_state", # Extra states (BytesIO type) added by TransformerEngine for FP8 handling
|
23 |
+
]
|
24 |
+
|
25 |
+
|
26 |
+
def get_partial_state_dict(
|
27 |
+
state_dict: Dict[str, torch.Tensor],
|
28 |
+
prefix: str,
|
29 |
+
) -> Dict[str, torch.Tensor]:
|
30 |
+
"""
|
31 |
+
Get a partial state dict with keys starting with the given prefix
|
32 |
+
"""
|
33 |
+
return {k: v for k, v in state_dict.items() if k.startswith(prefix)}
|
34 |
+
|
35 |
+
|
36 |
+
def process_state_dict(
|
37 |
+
state_dict: Dict[str, torch.Tensor],
|
38 |
+
device: str = None,
|
39 |
+
dtype: torch.dtype = None,
|
40 |
+
prefix_to_remove: Optional[str] = None,
|
41 |
+
) -> Dict[str, torch.Tensor]:
|
42 |
+
"""
|
43 |
+
- Remove items with substring "_extra_state" in keys (TransformerEngine adds these for FP8)
|
44 |
+
- Move tensors to specified device and dtype if provided
|
45 |
+
|
46 |
+
Args:
|
47 |
+
state_dict (Dict[str, torch.Tensor]): The state dict to process
|
48 |
+
device (str, optional): The device to move tensors to. Defaults to None.
|
49 |
+
dtype (torch.dtype, optional): The dtype to move tensors to. Defaults to None.
|
50 |
+
prefix_to_remove (str, optional): The prefix to remove from the keys of the state dict. Defaults to None.
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
Dict[str, torch.Tensor]: The processed state dict
|
54 |
+
"""
|
55 |
+
new_state_dict = {}
|
56 |
+
tensor_kwargs = {}
|
57 |
+
if device is not None:
|
58 |
+
tensor_kwargs["device"] = device
|
59 |
+
if dtype is not None:
|
60 |
+
tensor_kwargs["dtype"] = dtype
|
61 |
+
|
62 |
+
for key, value in state_dict.items():
|
63 |
+
# Check if any of the substrings to ignore are in the key
|
64 |
+
skip = False
|
65 |
+
for substr in substrings_to_ignore:
|
66 |
+
if substr in key:
|
67 |
+
skip = True
|
68 |
+
break
|
69 |
+
if skip:
|
70 |
+
continue
|
71 |
+
if len(tensor_kwargs) > 0:
|
72 |
+
value = value.to(**tensor_kwargs)
|
73 |
+
if prefix_to_remove is not None and key.startswith(prefix_to_remove):
|
74 |
+
key = key[len(prefix_to_remove) :]
|
75 |
+
new_state_dict[key] = value
|
76 |
+
return new_state_dict
|
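The two helpers above are typically combined when loading a raw checkpoint: filter the keys you need, drop TransformerEngine's `_extra_state` entries, strip the prefix, and cast/move the tensors in one pass. A minimal usage sketch, assuming a hypothetical checkpoint path and a hypothetical `"net."` key prefix (neither is taken from this repository):

```python
import torch

from checkpoint import get_partial_state_dict, process_state_dict  # adjust to your package path

# Hypothetical checkpoint file containing a flat state dict.
raw_state = torch.load("checkpoints/example_model.pt", map_location="cpu")

# Keep only the network weights (keys starting with "net."), then clean them up:
# drop "_extra_state" entries, strip the prefix, and cast to bfloat16 on CPU.
net_state = get_partial_state_dict(raw_state, prefix="net.")
net_state = process_state_dict(net_state, device="cpu", dtype=torch.bfloat16, prefix_to_remove="net.")

# model.load_state_dict(net_state, strict=False)  # assuming `model` is the matching nn.Module
```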
conditioner.py
ADDED
@@ -0,0 +1,323 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import copy
|
17 |
+
from abc import ABC, abstractmethod
|
18 |
+
from collections import defaultdict
|
19 |
+
from dataclasses import dataclass, fields
|
20 |
+
from enum import Enum
|
21 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import torch.nn as nn
|
25 |
+
|
26 |
+
from .batch_ops import batch_mul
|
27 |
+
from .log import log
|
28 |
+
from .lazy_config_init import instantiate
|
29 |
+
|
30 |
+
|
31 |
+
class BaseConditionEntry(nn.Module):
|
32 |
+
def __init__(self):
|
33 |
+
super().__init__()
|
34 |
+
|
35 |
+
self._dropout_rate = None
|
36 |
+
self._input_key = None
|
37 |
+
self._return_dict = False
|
38 |
+
|
39 |
+
@property
|
40 |
+
def dropout_rate(self) -> Union[float, torch.Tensor]:
|
41 |
+
return self._dropout_rate
|
42 |
+
|
43 |
+
@property
|
44 |
+
def input_key(self) -> str:
|
45 |
+
return self._input_key
|
46 |
+
|
47 |
+
@property
|
48 |
+
def is_return_dict(self) -> bool:
|
49 |
+
return self._return_dict
|
50 |
+
|
51 |
+
@dropout_rate.setter
|
52 |
+
def dropout_rate(self, value: Union[float, torch.Tensor]):
|
53 |
+
self._dropout_rate = value
|
54 |
+
|
55 |
+
@input_key.setter
|
56 |
+
def input_key(self, value: str):
|
57 |
+
self._input_key = value
|
58 |
+
|
59 |
+
@is_return_dict.setter
|
60 |
+
def is_return_dict(self, value: bool):
|
61 |
+
self._return_dict = value
|
62 |
+
|
63 |
+
@dropout_rate.deleter
|
64 |
+
def dropout_rate(self):
|
65 |
+
del self._dropout_rate
|
66 |
+
|
67 |
+
@input_key.deleter
|
68 |
+
def input_key(self):
|
69 |
+
del self._input_key
|
70 |
+
|
71 |
+
@is_return_dict.deleter
|
72 |
+
def is_return_dict(self):
|
73 |
+
del self._return_dict
|
74 |
+
|
75 |
+
def random_dropout_input(
|
76 |
+
self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
|
77 |
+
) -> torch.Tensor:
|
78 |
+
del key
|
79 |
+
dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate
|
80 |
+
return batch_mul(
|
81 |
+
torch.bernoulli((1.0 - dropout_rate) * torch.ones(in_tensor.shape[0])).type_as(in_tensor),
|
82 |
+
in_tensor,
|
83 |
+
)
|
84 |
+
|
85 |
+
def summary(self) -> str:
|
86 |
+
pass
|
87 |
+
|
88 |
+
|
89 |
+
class DataType(Enum):
|
90 |
+
IMAGE = "image"
|
91 |
+
VIDEO = "video"
|
92 |
+
|
93 |
+
|
94 |
+
class TextAttr(BaseConditionEntry):
|
95 |
+
def __init__(self):
|
96 |
+
super().__init__()
|
97 |
+
|
98 |
+
def forward(self, token: torch.Tensor, mask: torch.Tensor):
|
99 |
+
return {"crossattn_emb": token, "crossattn_mask": mask}
|
100 |
+
|
101 |
+
def random_dropout_input(
|
102 |
+
self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
|
103 |
+
) -> torch.Tensor:
|
104 |
+
if key is not None and "mask" in key:
|
105 |
+
return in_tensor
|
106 |
+
return super().random_dropout_input(in_tensor, dropout_rate, key)
|
107 |
+
|
108 |
+
|
109 |
+
@dataclass
|
110 |
+
class BaseVideoCondition:
|
111 |
+
crossattn_emb: torch.Tensor
|
112 |
+
crossattn_mask: torch.Tensor
|
113 |
+
data_type: DataType = DataType.VIDEO
|
114 |
+
padding_mask: Optional[torch.Tensor] = None
|
115 |
+
fps: Optional[torch.Tensor] = None
|
116 |
+
num_frames: Optional[torch.Tensor] = None
|
117 |
+
image_size: Optional[torch.Tensor] = None
|
118 |
+
scalar_feature: Optional[torch.Tensor] = None
|
119 |
+
|
120 |
+
def to_dict(self) -> Dict[str, Optional[torch.Tensor]]:
|
121 |
+
return {f.name: getattr(self, f.name) for f in fields(self)}
|
122 |
+
|
123 |
+
|
124 |
+
@dataclass
|
125 |
+
class VideoExtendCondition(BaseVideoCondition):
|
126 |
+
video_cond_bool: Optional[torch.Tensor] = None  # whether or not it is conditioned on video
|
127 |
+
gt_latent: Optional[torch.Tensor] = None
|
128 |
+
condition_video_indicator: Optional[torch.Tensor] = None # 1 for condition region
|
129 |
+
|
130 |
+
# condition_video_input_mask will be concatenated to the network input along the channel dim;
|
131 |
+
# i.e., it is concatenated with the input tensor
|
132 |
+
condition_video_input_mask: Optional[torch.Tensor] = None
|
133 |
+
# condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation, only valid when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed"
|
134 |
+
condition_video_augment_sigma: Optional[torch.Tensor] = None
|
135 |
+
|
136 |
+
|
137 |
+
class GeneralConditioner(nn.Module, ABC):
|
138 |
+
"""
|
139 |
+
An abstract module designed to handle various embedding models with conditional and
|
140 |
+
unconditional configurations. This abstract base class initializes and manages a collection
|
141 |
+
of embedders that can dynamically adjust their dropout rates based on conditioning.
|
142 |
+
|
143 |
+
Attributes:
|
144 |
+
KEY2DIM (dict): A mapping from output keys to dimensions used for concatenation.
|
145 |
+
embedders (nn.ModuleDict): A dictionary containing all embedded models initialized and
|
146 |
+
configured based on the provided configurations.
|
147 |
+
|
148 |
+
Parameters:
|
149 |
+
emb_models (Union[List, Any]): A dictionary where keys are embedder names and values
|
150 |
+
are configurations for initializing the embedders.
|
151 |
+
|
152 |
+
"""
|
153 |
+
|
154 |
+
KEY2DIM = {"crossattn_emb": 1, "crossattn_mask": 1}
|
155 |
+
|
156 |
+
def __init__(self, **emb_models: Union[List, Any]):
|
157 |
+
super().__init__()
|
158 |
+
self.embedders = nn.ModuleDict()
|
159 |
+
for n, (emb_name, embconfig) in enumerate(emb_models.items()):
|
160 |
+
embedder = instantiate(embconfig.obj)
|
161 |
+
assert isinstance(
|
162 |
+
embedder, BaseConditionEntry
|
163 |
+
), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
|
164 |
+
embedder.dropout_rate = getattr(embconfig, "dropout_rate", 0.0)
|
165 |
+
|
166 |
+
if hasattr(embconfig, "input_key"):
|
167 |
+
embedder.input_key = embconfig.input_key
|
168 |
+
elif hasattr(embconfig, "input_keys"):
|
169 |
+
embedder.input_keys = embconfig.input_keys
|
170 |
+
else:
|
171 |
+
raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}")
|
172 |
+
|
173 |
+
log.debug(f"Initialized embedder #{n}-{emb_name}: \n {embedder.summary()}")
|
174 |
+
self.embedders[emb_name] = embedder
|
175 |
+
|
176 |
+
@abstractmethod
|
177 |
+
def forward(
|
178 |
+
self,
|
179 |
+
batch: Dict,
|
180 |
+
override_dropout_rate: Optional[Dict[str, float]] = None,
|
181 |
+
) -> Any:
|
182 |
+
"""Should be implemented in subclasses to handle conditon datatype"""
|
183 |
+
raise NotImplementedError
|
184 |
+
|
185 |
+
def _forward(
|
186 |
+
self,
|
187 |
+
batch: Dict,
|
188 |
+
override_dropout_rate: Optional[Dict[str, float]] = None,
|
189 |
+
) -> Dict:
|
190 |
+
"""
|
191 |
+
Processes the input batch through all configured embedders, applying conditional dropout rates if specified.
|
192 |
+
Output tensors for each key are concatenated along the dimensions specified in KEY2DIM.
|
193 |
+
|
194 |
+
Parameters:
|
195 |
+
batch (Dict): The input data batch to process.
|
196 |
+
override_dropout_rate (Optional[Dict[str, float]]): Optional dictionary to override default dropout rates
|
197 |
+
per embedder key.
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
Dict: A dictionary of output tensors concatenated by specified dimensions.
|
201 |
+
|
202 |
+
Note:
|
203 |
+
In case the network code is sensitive to the order of concatenation, you can either control the order via \
|
204 |
+
config file or make sure the embedders return a unique key for each output.
|
205 |
+
"""
|
206 |
+
output = defaultdict(list)
|
207 |
+
if override_dropout_rate is None:
|
208 |
+
override_dropout_rate = {}
|
209 |
+
|
210 |
+
# make sure emb_name in override_dropout_rate is valid
|
211 |
+
for emb_name in override_dropout_rate.keys():
|
212 |
+
assert emb_name in self.embedders, f"invalid name found {emb_name}"
|
213 |
+
|
214 |
+
for emb_name, embedder in self.embedders.items():
|
215 |
+
with torch.no_grad():
|
216 |
+
if hasattr(embedder, "input_key") and (embedder.input_key is not None):
|
217 |
+
emb_out = embedder(
|
218 |
+
embedder.random_dropout_input(
|
219 |
+
batch[embedder.input_key], override_dropout_rate.get(emb_name, None)
|
220 |
+
)
|
221 |
+
)
|
222 |
+
elif hasattr(embedder, "input_keys"):
|
223 |
+
emb_out = embedder(
|
224 |
+
*[
|
225 |
+
embedder.random_dropout_input(batch[k], override_dropout_rate.get(emb_name, None), k)
|
226 |
+
for k in embedder.input_keys
|
227 |
+
]
|
228 |
+
)
|
229 |
+
for k, v in emb_out.items():
|
230 |
+
output[k].append(v)
|
231 |
+
# Concatenate the outputs
|
232 |
+
return {k: torch.cat(v, dim=self.KEY2DIM.get(k, -1)) for k, v in output.items()}
|
233 |
+
|
234 |
+
def get_condition_uncondition(
|
235 |
+
self,
|
236 |
+
data_batch: Dict,
|
237 |
+
) -> Tuple[Any, Any]:
|
238 |
+
"""
|
239 |
+
Processes the provided data batch to generate conditioned and unconditioned outputs.
|
240 |
+
|
241 |
+
This method manipulates dropout rates to simulate two scenarios:
|
242 |
+
1. All conditions applied (conditioned)
|
243 |
+
2. Conditions removed/reduced to minimum (unconditioned)
|
244 |
+
|
245 |
+
This method sets dropout rates to zero for the conditioned scenario to fully apply
|
246 |
+
embedders' effects. For unconditioned, it sets rates to 1 (or 0 if initial rate is
|
247 |
+
insignificant) to minimize embedder influences.
|
248 |
+
|
249 |
+
Parameters:
|
250 |
+
data_batch (Dict): Input data batch containing all necessary information for
|
251 |
+
embedding processing.
|
252 |
+
|
253 |
+
Returns:
|
254 |
+
Tuple[Any, Any]: A tuple containing:
|
255 |
+
- Outputs with all embedders fully applied (conditioned)
|
256 |
+
- Outputs with embedders minimized/not applied (unconditioned)
|
257 |
+
"""
|
258 |
+
cond_dropout_rates, dropout_rates = {}, {}
|
259 |
+
for emb_name, embedder in self.embedders.items():
|
260 |
+
cond_dropout_rates[emb_name] = 0.0
|
261 |
+
dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0
|
262 |
+
|
263 |
+
condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates)
|
264 |
+
un_condition: Any = self(data_batch, override_dropout_rate=dropout_rates)
|
265 |
+
return condition, un_condition
|
266 |
+
|
267 |
+
def get_condition_with_negative_prompt(
|
268 |
+
self,
|
269 |
+
data_batch: Dict,
|
270 |
+
) -> Tuple[Any, Any]:
|
271 |
+
"""
|
272 |
+
Similar functionality to get_condition_uncondition,
|
273 |
+
but uses negative prompts for the unconditioned output
|
274 |
+
"""
|
275 |
+
cond_dropout_rates, uncond_dropout_rates = {}, {}
|
276 |
+
for emb_name, embedder in self.embedders.items():
|
277 |
+
cond_dropout_rates[emb_name] = 0.0
|
278 |
+
if isinstance(embedder, TextAttr):
|
279 |
+
uncond_dropout_rates[emb_name] = 0.0
|
280 |
+
else:
|
281 |
+
uncond_dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0
|
282 |
+
|
283 |
+
data_batch_neg_prompt = copy.deepcopy(data_batch)
|
284 |
+
if "neg_t5_text_embeddings" in data_batch_neg_prompt:
|
285 |
+
if isinstance(data_batch_neg_prompt["neg_t5_text_embeddings"], torch.Tensor):
|
286 |
+
data_batch_neg_prompt["t5_text_embeddings"] = data_batch_neg_prompt["neg_t5_text_embeddings"]
|
287 |
+
data_batch_neg_prompt["t5_text_mask"] = data_batch_neg_prompt["neg_t5_text_mask"]
|
288 |
+
|
289 |
+
condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates)
|
290 |
+
un_condition: Any = self(data_batch_neg_prompt, override_dropout_rate=uncond_dropout_rates)
|
291 |
+
|
292 |
+
return condition, un_condition
|
293 |
+
|
294 |
+
|
295 |
+
@dataclass
|
296 |
+
class CosmosCondition:
|
297 |
+
crossattn_emb: torch.Tensor
|
298 |
+
crossattn_mask: torch.Tensor
|
299 |
+
padding_mask: Optional[torch.Tensor] = None
|
300 |
+
scalar_feature: Optional[torch.Tensor] = None
|
301 |
+
|
302 |
+
def to_dict(self) -> Dict[str, Optional[torch.Tensor]]:
|
303 |
+
return {f.name: getattr(self, f.name) for f in fields(self)}
|
304 |
+
|
305 |
+
|
306 |
+
class VideoConditioner(GeneralConditioner):
|
307 |
+
def forward(
|
308 |
+
self,
|
309 |
+
batch: Dict,
|
310 |
+
override_dropout_rate: Optional[Dict[str, float]] = None,
|
311 |
+
) -> BaseVideoCondition:
|
312 |
+
output = super()._forward(batch, override_dropout_rate)
|
313 |
+
return BaseVideoCondition(**output)
|
314 |
+
|
315 |
+
|
316 |
+
class VideoExtendConditioner(GeneralConditioner):
|
317 |
+
def forward(
|
318 |
+
self,
|
319 |
+
batch: Dict,
|
320 |
+
override_dropout_rate: Optional[Dict[str, float]] = None,
|
321 |
+
) -> VideoExtendCondition:
|
322 |
+
output = super()._forward(batch, override_dropout_rate)
|
323 |
+
return VideoExtendCondition(**output)
|
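As a rough illustration of how these pieces fit together, the sketch below feeds dummy T5 embeddings through `TextAttr` and wraps the result in a `BaseVideoCondition`. The tensor shapes and the flat import path are assumptions for the example; inside the package the module is imported relatively (e.g. `from .conditioner import ...`).

```python
import torch

from conditioner import BaseVideoCondition, TextAttr  # adjust to the package path in your checkout

# Dummy batch: 2 prompts, 512 tokens, 1024-dim T5 embeddings (shapes are illustrative).
tokens = torch.randn(2, 512, 1024)
mask = torch.ones(2, 512)

text = TextAttr()
text.dropout_rate = 0.2  # during training, each sample's embedding is zeroed with probability 0.2

# Dropout is applied to the embeddings but not to the mask (see TextAttr.random_dropout_input above).
tokens = text.random_dropout_input(tokens, key="t5_text_embeddings")
out = text(tokens, mask)  # {"crossattn_emb": ..., "crossattn_mask": ...}

cond = BaseVideoCondition(crossattn_emb=out["crossattn_emb"], crossattn_mask=out["crossattn_mask"])
print(sorted(cond.to_dict().keys()))
```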
config.json
ADDED
@@ -0,0 +1,10 @@
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"DiffusionText2World"
|
4 |
+
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "text2world_hf.DiffusionText2WorldConfig",
|
7 |
+
"AutoModel": "text2world_hf.DiffusionText2World"
|
8 |
+
},
|
9 |
+
"model_type": "AutoModel"
|
10 |
+
}
|
config.py
ADDED
@@ -0,0 +1,166 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from __future__ import annotations
|
17 |
+
|
18 |
+
from typing import Any, TypeVar
|
19 |
+
|
20 |
+
import attrs
|
21 |
+
|
22 |
+
from omegaconf import DictConfig as LazyDict
|
23 |
+
|
24 |
+
from .misc import Color
|
25 |
+
|
26 |
+
T = TypeVar("T")
|
27 |
+
|
28 |
+
|
29 |
+
def _is_attrs_instance(obj: object) -> bool:
|
30 |
+
"""
|
31 |
+
Helper function to check if an object is an instance of an attrs-defined class.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
obj: The object to check.
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
bool: True if the object is an instance of an attrs-defined class, False otherwise.
|
38 |
+
"""
|
39 |
+
return hasattr(obj, "__attrs_attrs__")
|
40 |
+
|
41 |
+
|
42 |
+
def make_freezable(cls: T) -> T:
|
43 |
+
"""
|
44 |
+
A decorator that adds the capability to freeze instances of an attrs-defined class.
|
45 |
+
|
46 |
+
NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need
|
47 |
+
to hack on a "_is_frozen" attribute.
|
48 |
+
|
49 |
+
This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime.
|
50 |
+
Once an instance is frozen, its attributes cannot be changed. It also recursively freezes
|
51 |
+
any attrs-defined objects that are attributes of the class.
|
52 |
+
|
53 |
+
Usage:
|
54 |
+
@make_freezable
|
55 |
+
@attrs.define(slots=False)
|
56 |
+
class MyClass:
|
57 |
+
attribute1: int
|
58 |
+
attribute2: str
|
59 |
+
|
60 |
+
obj = MyClass(1, 'a')
|
61 |
+
obj.freeze() # Freeze the instance
|
62 |
+
obj.attribute1 = 2 # Raises AttributeError
|
63 |
+
|
64 |
+
Args:
|
65 |
+
cls: The class to be decorated.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
The decorated class with added freezing capability.
|
69 |
+
"""
|
70 |
+
|
71 |
+
if not hasattr(cls, "__dict__"):
|
72 |
+
raise TypeError(
|
73 |
+
"make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped "
|
74 |
+
"class was defined with `@attrs.define(slots=False)`"
|
75 |
+
)
|
76 |
+
|
77 |
+
original_setattr = cls.__setattr__
|
78 |
+
|
79 |
+
def setattr_override(self, key, value) -> None: # noqa: ANN001
|
80 |
+
"""
|
81 |
+
Override __setattr__ to allow modifications during initialization
|
82 |
+
and prevent modifications once the instance is frozen.
|
83 |
+
"""
|
84 |
+
if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen":
|
85 |
+
raise AttributeError("Cannot modify frozen instance")
|
86 |
+
original_setattr(self, key, value) # type: ignore
|
87 |
+
|
88 |
+
cls.__setattr__ = setattr_override # type: ignore
|
89 |
+
|
90 |
+
def freeze(self: object) -> None:
|
91 |
+
"""
|
92 |
+
Freeze the instance and all its attrs-defined attributes.
|
93 |
+
"""
|
94 |
+
for _, value in attrs.asdict(self, recurse=False).items():
|
95 |
+
if _is_attrs_instance(value) and hasattr(value, "freeze"):
|
96 |
+
value.freeze()
|
97 |
+
self._is_frozen = True # type: ignore
|
98 |
+
|
99 |
+
cls.freeze = freeze # type: ignore
|
100 |
+
|
101 |
+
return cls
|
102 |
+
|
103 |
+
|
104 |
+
def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str:
|
105 |
+
"""
|
106 |
+
Recursively pretty prints attrs objects with color.
|
107 |
+
"""
|
108 |
+
|
109 |
+
assert attrs.has(obj.__class__)
|
110 |
+
|
111 |
+
lines: list[str] = []
|
112 |
+
for attribute in attrs.fields(obj.__class__):
|
113 |
+
value = getattr(obj, attribute.name)
|
114 |
+
if attrs.has(value.__class__):
|
115 |
+
if use_color:
|
116 |
+
lines.append(" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ":")
|
117 |
+
else:
|
118 |
+
lines.append(" " * indent + "* " + attribute.name + ":")
|
119 |
+
lines.append(_pretty_print_attrs_instance(value, indent + 1, use_color))
|
120 |
+
else:
|
121 |
+
if use_color:
|
122 |
+
lines.append(
|
123 |
+
" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ": " + Color.yellow(value)
|
124 |
+
)
|
125 |
+
else:
|
126 |
+
lines.append(" " * indent + "* " + attribute.name + ": " + str(value))
|
127 |
+
return "\n".join(lines)
|
128 |
+
|
129 |
+
|
130 |
+
@make_freezable
|
131 |
+
@attrs.define(slots=False)
|
132 |
+
class JobConfig:
|
133 |
+
# Project name.
|
134 |
+
project: str = ""
|
135 |
+
# Experiment name.
|
136 |
+
group: str = ""
|
137 |
+
# Run/job name.
|
138 |
+
name: str = ""
|
139 |
+
|
140 |
+
@property
|
141 |
+
def path(self) -> str:
|
142 |
+
return f"{self.project}/{self.group}/{self.name}"
|
143 |
+
|
144 |
+
|
145 |
+
@make_freezable
|
146 |
+
@attrs.define(slots=False)
|
147 |
+
class Config:
|
148 |
+
"""Config for a job.
|
149 |
+
|
150 |
+
See /README.md/Configuration System for more info.
|
151 |
+
"""
|
152 |
+
|
153 |
+
# Model configs.
|
154 |
+
model: LazyDict
|
155 |
+
|
156 |
+
# Training job configs.
|
157 |
+
job: JobConfig = attrs.field(factory=JobConfig)
|
158 |
+
|
159 |
+
def to_dict(self) -> dict[str, Any]:
|
160 |
+
return attrs.asdict(self)
|
161 |
+
|
162 |
+
def validate(self) -> None:
|
163 |
+
"""Validate that the config has all required fields."""
|
164 |
+
assert self.job.project != "", "Project name is required."
|
165 |
+
assert self.job.group != "", "Group name is required."
|
166 |
+
assert self.job.name != "", "Job name is required."
|
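To make the freezing behaviour concrete, here is a small sketch that builds a `Config`, validates it, and then freezes it. The model payload is a placeholder `DictConfig`, and the project/group/name values are arbitrary examples.

```python
from omegaconf import DictConfig

from config import Config, JobConfig  # adjust to the package path in your checkout

# Placeholder model config; in practice this is a LazyDict describing the model to instantiate.
cfg = Config(
    model=DictConfig({"name": "toy_model"}),
    job=JobConfig(project="cosmos", group="diffusion", name="demo_run"),
)

cfg.validate()           # asserts that project/group/name are all non-empty
print(cfg.job.path)      # -> "cosmos/diffusion/demo_run"

cfg.freeze()             # added by @make_freezable; freezes cfg and nested attrs objects
try:
    cfg.job.name = "other"
except AttributeError as e:
    print(e)             # "Cannot modify frozen instance"
```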
config_base_conditioner.py
ADDED
@@ -0,0 +1,169 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Dict, List, Optional
|
17 |
+
|
18 |
+
import attrs
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from .conditioner import BaseConditionEntry, TextAttr, VideoConditioner, VideoExtendConditioner
|
22 |
+
from .lazy_config_init import LazyCall as L
|
23 |
+
from .lazy_config_init import LazyDict
|
24 |
+
|
25 |
+
|
26 |
+
@attrs.define(slots=False)
|
27 |
+
class TextConfig:
|
28 |
+
obj: LazyDict = L(TextAttr)() # No arguments
|
29 |
+
dropout_rate: float = 0.2
|
30 |
+
input_keys: List[str] = attrs.field(factory=lambda: ["t5_text_embeddings", "t5_text_mask"])
|
31 |
+
|
32 |
+
|
33 |
+
class BooleanFlag(BaseConditionEntry):
|
34 |
+
def __init__(self, output_key: Optional[str] = None):
|
35 |
+
super().__init__()
|
36 |
+
self.output_key = output_key
|
37 |
+
|
38 |
+
def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
|
39 |
+
del args, kwargs
|
40 |
+
key = self.output_key if self.output_key else self.input_key
|
41 |
+
return {key: self.flag}
|
42 |
+
|
43 |
+
def random_dropout_input(
|
44 |
+
self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
|
45 |
+
) -> torch.Tensor:
|
46 |
+
del key
|
47 |
+
dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate
|
48 |
+
self.flag = torch.bernoulli((1.0 - dropout_rate) * torch.ones(1)).bool().to(device=in_tensor.device)
|
49 |
+
return in_tensor
|
50 |
+
|
51 |
+
|
52 |
+
class ReMapkey(BaseConditionEntry):
|
53 |
+
def __init__(self, output_key: Optional[str] = None, dtype: Optional[str] = None):
|
54 |
+
super().__init__()
|
55 |
+
self.output_key = output_key
|
56 |
+
self.dtype = {
|
57 |
+
None: None,
|
58 |
+
"float": torch.float32,
|
59 |
+
"bfloat16": torch.bfloat16,
|
60 |
+
"half": torch.float16,
|
61 |
+
"float16": torch.float16,
|
62 |
+
"int": torch.int32,
|
63 |
+
"long": torch.int64,
|
64 |
+
}[dtype]
|
65 |
+
|
66 |
+
def forward(self, element: torch.Tensor) -> Dict[str, torch.Tensor]:
|
67 |
+
key = self.output_key if self.output_key else self.input_key
|
68 |
+
if isinstance(element, torch.Tensor):
|
69 |
+
element = element.to(dtype=self.dtype)
|
70 |
+
return {key: element}
|
71 |
+
|
72 |
+
|
73 |
+
@attrs.define(slots=False)
|
74 |
+
class FPSConfig:
|
75 |
+
"""
|
76 |
+
Remap the key from the input dictionary to the output dictionary. For `fps`.
|
77 |
+
"""
|
78 |
+
|
79 |
+
obj: LazyDict = L(ReMapkey)(output_key="fps", dtype=None)
|
80 |
+
dropout_rate: float = 0.0
|
81 |
+
input_key: str = "fps"
|
82 |
+
|
83 |
+
|
84 |
+
@attrs.define(slots=False)
|
85 |
+
class PaddingMaskConfig:
|
86 |
+
"""
|
87 |
+
Remap the key from the input dictionary to the output dictionary. For `padding_mask`.
|
88 |
+
"""
|
89 |
+
|
90 |
+
obj: LazyDict = L(ReMapkey)(output_key="padding_mask", dtype=None)
|
91 |
+
dropout_rate: float = 0.0
|
92 |
+
input_key: str = "padding_mask"
|
93 |
+
|
94 |
+
|
95 |
+
@attrs.define(slots=False)
|
96 |
+
class ImageSizeConfig:
|
97 |
+
"""
|
98 |
+
Remap the key from the input dictionary to the output dictionary. For `image_size`.
|
99 |
+
"""
|
100 |
+
|
101 |
+
obj: LazyDict = L(ReMapkey)(output_key="image_size", dtype=None)
|
102 |
+
dropout_rate: float = 0.0
|
103 |
+
input_key: str = "image_size"
|
104 |
+
|
105 |
+
|
106 |
+
@attrs.define(slots=False)
|
107 |
+
class NumFramesConfig:
|
108 |
+
"""
|
109 |
+
Remap the key from the input dictionary to the output dictionary. For `num_frames`.
|
110 |
+
"""
|
111 |
+
|
112 |
+
obj: LazyDict = L(ReMapkey)(output_key="num_frames", dtype=None)
|
113 |
+
dropout_rate: float = 0.0
|
114 |
+
input_key: str = "num_frames"
|
115 |
+
|
116 |
+
|
117 |
+
@attrs.define(slots=False)
|
118 |
+
class VideoCondBoolConfig:
|
119 |
+
obj: LazyDict = L(BooleanFlag)(output_key="video_cond_bool")
|
120 |
+
dropout_rate: float = 0.2
|
121 |
+
input_key: str = "fps" # This is a placeholder, we never use this value
|
122 |
+
# Config below are for long video generation only
|
123 |
+
|
124 |
+
# Sample PPP... from IPPP... sequence
|
125 |
+
sample_tokens_start_from_p_or_i: bool = False
|
126 |
+
|
127 |
+
|
128 |
+
@attrs.define(slots=False)
|
129 |
+
class LatentConditionConfig:
|
130 |
+
"""
|
131 |
+
Remap the key from the input dictionary to the output dictionary. For `latent condition`.
|
132 |
+
"""
|
133 |
+
|
134 |
+
obj: LazyDict = L(ReMapkey)(output_key="latent_condition", dtype=None)
|
135 |
+
dropout_rate: float = 0.0
|
136 |
+
input_key: str = "latent_condition"
|
137 |
+
|
138 |
+
|
139 |
+
@attrs.define(slots=False)
|
140 |
+
class LatentConditionSigmaConfig:
|
141 |
+
"""
|
142 |
+
Remap the key from the input dictionary to the output dictionary. For `latent condition`.
|
143 |
+
"""
|
144 |
+
|
145 |
+
obj: LazyDict = L(ReMapkey)(output_key="latent_condition_sigma", dtype=None)
|
146 |
+
dropout_rate: float = 0.0
|
147 |
+
input_key: str = "latent_condition_sigma"
|
148 |
+
|
149 |
+
|
150 |
+
BaseVideoConditionerConfig: LazyDict = L(VideoConditioner)(
|
151 |
+
text=TextConfig(),
|
152 |
+
)
|
153 |
+
|
154 |
+
VideoConditionerFpsSizePaddingConfig: LazyDict = L(VideoConditioner)(
|
155 |
+
text=TextConfig(),
|
156 |
+
fps=FPSConfig(),
|
157 |
+
num_frames=NumFramesConfig(),
|
158 |
+
image_size=ImageSizeConfig(),
|
159 |
+
padding_mask=PaddingMaskConfig(),
|
160 |
+
)
|
161 |
+
|
162 |
+
VideoExtendConditionerConfig: LazyDict = L(VideoExtendConditioner)(
|
163 |
+
text=TextConfig(),
|
164 |
+
fps=FPSConfig(),
|
165 |
+
num_frames=NumFramesConfig(),
|
166 |
+
image_size=ImageSizeConfig(),
|
167 |
+
padding_mask=PaddingMaskConfig(),
|
168 |
+
video_cond_bool=VideoCondBoolConfig(),
|
169 |
+
)
|
config_helper.py
ADDED
@@ -0,0 +1,198 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import importlib
|
17 |
+
import os
|
18 |
+
import pkgutil
|
19 |
+
import sys
|
20 |
+
from dataclasses import fields as dataclass_fields
|
21 |
+
from dataclasses import is_dataclass
|
22 |
+
from typing import Any, Dict, Optional
|
23 |
+
|
24 |
+
import attr
|
25 |
+
import attrs
|
26 |
+
from hydra import compose, initialize
|
27 |
+
from hydra.core.config_store import ConfigStore
|
28 |
+
from omegaconf import DictConfig, OmegaConf
|
29 |
+
|
30 |
+
from .log import log
|
31 |
+
from .config import Config
|
32 |
+
from .inference import *
|
33 |
+
|
34 |
+
|
35 |
+
def is_attrs_or_dataclass(obj) -> bool:
|
36 |
+
"""
|
37 |
+
Check if the object is an instance of an attrs class or a dataclass.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
obj: The object to check.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
bool: True if the object is an instance of an attrs class or a dataclass, False otherwise.
|
44 |
+
"""
|
45 |
+
return is_dataclass(obj) or attr.has(type(obj))
|
46 |
+
|
47 |
+
|
48 |
+
def get_fields(obj):
|
49 |
+
"""
|
50 |
+
Get the fields of an attrs class or a dataclass.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
obj: The object to get fields from. Must be an instance of an attrs class or a dataclass.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
list: A list of field names.
|
57 |
+
|
58 |
+
Raises:
|
59 |
+
ValueError: If the object is neither an attrs class nor a dataclass.
|
60 |
+
"""
|
61 |
+
if is_dataclass(obj):
|
62 |
+
return [field.name for field in dataclass_fields(obj)]
|
63 |
+
elif attr.has(type(obj)):
|
64 |
+
return [field.name for field in attr.fields(type(obj))]
|
65 |
+
else:
|
66 |
+
raise ValueError("The object is neither an attrs class nor a dataclass.")
|
67 |
+
|
68 |
+
|
69 |
+
def override(config: Config, overrides: Optional[list[str]] = None) -> Config:
|
70 |
+
"""
|
71 |
+
:param config: the instance of class `Config` (usually from `make_config`)
|
72 |
+
:param overrides: list of overrides for config
|
73 |
+
:return: the composed instance of class `Config`
|
74 |
+
"""
|
75 |
+
# Store the class of the config for reconstruction after overriding.
|
76 |
+
# config_class = type(config)
|
77 |
+
|
78 |
+
# Convert Config object to a DictConfig object
|
79 |
+
config_dict = attrs.asdict(config)
|
80 |
+
config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
|
81 |
+
# Enforce "--" separator between the script arguments and overriding configs.
|
82 |
+
if overrides:
|
83 |
+
if overrides[0] != "--":
|
84 |
+
raise ValueError('Hydra config overrides must be separated with a "--" token.')
|
85 |
+
overrides = overrides[1:]
|
86 |
+
# Use Hydra to handle overrides
|
87 |
+
cs = ConfigStore.instance()
|
88 |
+
cs.store(name="config", node=config_omegaconf)
|
89 |
+
with initialize(version_base=None):
|
90 |
+
config_omegaconf = compose(config_name="config", overrides=overrides)
|
91 |
+
OmegaConf.resolve(config_omegaconf)
|
92 |
+
|
93 |
+
def config_from_dict(ref_instance: Any, kwargs: Any) -> Any:
|
94 |
+
"""
|
95 |
+
Construct an instance of the same type as ref_instance using the provided dictionary or data or unstructured data
|
96 |
+
|
97 |
+
Args:
|
98 |
+
ref_instance: The reference instance to determine the type and fields when needed
|
99 |
+
kwargs: A dictionary of keyword arguments to use for constructing the new instance or primitive data or unstructured data
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
Any: A new instance of the same type as ref_instance constructed using the provided kwargs or the primitive data or unstructured data
|
103 |
+
|
104 |
+
Raises:
|
105 |
+
AssertionError: If the fields do not match or if extra keys are found.
|
106 |
+
Exception: If there is an error constructing the new instance.
|
107 |
+
"""
|
108 |
+
is_type = is_attrs_or_dataclass(ref_instance)
|
109 |
+
if not is_type:
|
110 |
+
return kwargs
|
111 |
+
else:
|
112 |
+
ref_fields = set(get_fields(ref_instance))
|
113 |
+
assert isinstance(kwargs, dict) or isinstance(
|
114 |
+
kwargs, DictConfig
|
115 |
+
), "kwargs must be a dictionary or a DictConfig"
|
116 |
+
keys = set(kwargs.keys())
|
117 |
+
|
118 |
+
# ref_fields must equal to or include all keys
|
119 |
+
extra_keys = keys - ref_fields
|
120 |
+
assert ref_fields == keys or keys.issubset(
|
121 |
+
ref_fields
|
122 |
+
), f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}"
|
123 |
+
|
124 |
+
resolved_kwargs: Dict[str, Any] = {}
|
125 |
+
for f in keys:
|
126 |
+
resolved_kwargs[f] = config_from_dict(getattr(ref_instance, f), kwargs[f])
|
127 |
+
try:
|
128 |
+
new_instance = type(ref_instance)(**resolved_kwargs)
|
129 |
+
except Exception as e:
|
130 |
+
log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}")
|
131 |
+
log.error(e)
|
132 |
+
raise e
|
133 |
+
return new_instance
|
134 |
+
|
135 |
+
config = config_from_dict(config, config_omegaconf)
|
136 |
+
|
137 |
+
return config
|
138 |
+
|
139 |
+
|
140 |
+
def get_config_module(config_file: str) -> str:
|
141 |
+
if not config_file.endswith(".py"):
|
142 |
+
log.error("Config file cannot be specified as module.")
|
143 |
+
log.error("Please provide the path to the Python config file (relative to the Cosmos root).")
|
144 |
+
assert os.path.isfile(config_file), f"Cosmos config file ({config_file}) not found."
|
145 |
+
# Convert to importable module format.
|
146 |
+
config_module = config_file.replace("/", ".").replace(".py", "")
|
147 |
+
return config_module
|
148 |
+
|
149 |
+
|
150 |
+
def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None:
|
151 |
+
"""
|
152 |
+
Import all modules from the specified package path recursively.
|
153 |
+
|
154 |
+
This function is typically used in conjunction with Hydra to ensure that all modules
|
155 |
+
within a specified package are imported, which is necessary for registering configurations.
|
156 |
+
|
157 |
+
Example usage:
|
158 |
+
```python
|
159 |
+
import_all_modules_from_package("cosmos1.models.diffusion.config.inference", reload=True, skip_underscore=False)
|
160 |
+
```
|
161 |
+
|
162 |
+
Args:
|
163 |
+
package_path (str): The dotted path to the package from which to import all modules.
|
164 |
+
reload (bool): Flag to determine whether to reload modules if they're already imported.
|
165 |
+
skip_underscore (bool): If True, skips importing modules that start with an underscore.
|
166 |
+
"""
|
167 |
+
return # we do not use this function
|
168 |
+
log.debug(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}")
|
169 |
+
package = importlib.import_module(package_path)
|
170 |
+
package_directory = package.__path__
|
171 |
+
|
172 |
+
def import_modules_recursively(directory: str, prefix: str) -> None:
|
173 |
+
"""
|
174 |
+
Recursively imports or reloads all modules in the given directory.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
directory (str): The file system path to the current package directory.
|
178 |
+
prefix (str): The module prefix (e.g., 'cosmos1.models.diffusion.config').
|
179 |
+
"""
|
180 |
+
for _, module_name, is_pkg in pkgutil.iter_modules([directory]):
|
181 |
+
if skip_underscore and module_name.startswith("_"):
|
182 |
+
log.debug(f"Skipping module {module_name} as it starts with an underscore")
|
183 |
+
continue
|
184 |
+
|
185 |
+
full_module_name = f"{prefix}.{module_name}"
|
186 |
+
log.debug(f"{'Reloading' if reload else 'Importing'} module {full_module_name}")
|
187 |
+
|
188 |
+
if full_module_name in sys.modules and reload:
|
189 |
+
importlib.reload(sys.modules[full_module_name])
|
190 |
+
else:
|
191 |
+
importlib.import_module(full_module_name)
|
192 |
+
|
193 |
+
if is_pkg:
|
194 |
+
sub_package_directory = os.path.join(directory, module_name)
|
195 |
+
import_modules_recursively(sub_package_directory, full_module_name)
|
196 |
+
|
197 |
+
for directory in package_directory:
|
198 |
+
import_modules_recursively(directory, package_path)
|
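A short sketch of how `override` is meant to be called: the config usually comes from a `make_config()` helper (referenced in the docstring above but not shown here, so treat it as an assumption), and the override list must start with the `--` separator followed by Hydra-style dotted overrides.

```python
from config_helper import override  # adjust to the package path in your checkout

# `make_config()` is assumed to exist elsewhere in the repo and to return a `Config` instance.
config = make_config()

# Hydra-style overrides; note the mandatory leading "--" separator.
config = override(config, ["--", "job.project=cosmos", "job.group=diffusion", "job.name=demo"])

config.validate()
print(config.job.path)  # -> "cosmos/diffusion/demo"
```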
convert_pixtral_ckpt.py
ADDED
@@ -0,0 +1,209 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Convert pretrained Pixtral vision model weights to checkpoint and verify the checkpoint loading.
|
17 |
+
|
18 |
+
Usage:
|
19 |
+
|
20 |
+
PYTHONPATH=$(pwd) python cosmos1/scripts/convert_pixtral_ckpt.py
|
21 |
+
"""
|
22 |
+
|
23 |
+
import argparse
|
24 |
+
import json
|
25 |
+
import os
|
26 |
+
import shutil
|
27 |
+
from glob import glob
|
28 |
+
|
29 |
+
import torch
|
30 |
+
from huggingface_hub import snapshot_download
|
31 |
+
from safetensors.torch import load_file
|
32 |
+
|
33 |
+
|
34 |
+
def convert_pixtral_checkpoint(checkpoint_dir: str, checkpoint_name: str, vit_type: str):
|
35 |
+
"""
|
36 |
+
Main function to convert Pixtral vision model weights to checkpoint and optionally verify and save the converted checkpoint.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
checkpoint_dir (str): Path to the checkpoint directory
|
40 |
+
checkpoint_name (str): Name of the checkpoint
|
41 |
+
vit_type (str): Type of ViT used in the Pixtral model
|
42 |
+
|
43 |
+
This function performs the following steps:
|
44 |
+
0. Download the checkpoint from Hugging Face
|
45 |
+
1. Loads the original Pixtral checkpoint
|
46 |
+
2. Splits the checkpoint into vision encoder, projector, and LLM weights
|
47 |
+
3. Reorganizes the weights to match the expected format
|
48 |
+
4. Extracts and verifies the vision encoder configuration
|
49 |
+
5. Optionally verifies the converted checkpoint by loading it into a VisionTransformer
|
50 |
+
6. Optionally saves the converted checkpoint and configuration
|
51 |
+
"""
|
52 |
+
|
53 |
+
save_dir = os.path.join(checkpoint_dir, checkpoint_name)
|
54 |
+
os.makedirs(save_dir, exist_ok=True)
|
55 |
+
# Save the converted checkpoint
|
56 |
+
save_path = os.path.join(save_dir, "model.pt")
|
57 |
+
if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
|
58 |
+
print(f"Checkpoint {save_path} already exists and is not empty")
|
59 |
+
return
|
60 |
+
|
61 |
+
pixtral_ckpt_dir = os.path.join(checkpoint_dir, "Pixtral-12B-2409")
|
62 |
+
os.makedirs(pixtral_ckpt_dir, exist_ok=True)
|
63 |
+
repo_id = "mistralai/Pixtral-12B-2409"
|
64 |
+
print(f"Downloading {repo_id} to {pixtral_ckpt_dir}...")
|
65 |
+
snapshot_download(
|
66 |
+
repo_id=repo_id,
|
67 |
+
allow_patterns=["params.json", "consolidated.safetensors"],
|
68 |
+
local_dir=pixtral_ckpt_dir,
|
69 |
+
local_dir_use_symlinks=False,
|
70 |
+
)
|
71 |
+
orig_dtype = torch.get_default_dtype()
|
72 |
+
dtype = torch.bfloat16
|
73 |
+
torch.set_default_dtype(dtype)
|
74 |
+
|
75 |
+
# Load checkpoint file
|
76 |
+
ckpt_files = glob(os.path.join(pixtral_ckpt_dir, "*.safetensors"))
|
77 |
+
assert len(ckpt_files) == 1, "ckpt_dir should contain only one file"
|
78 |
+
ckpt_path = ckpt_files[0]
|
79 |
+
ckpt = load_file(ckpt_path)
|
80 |
+
|
81 |
+
# Split checkpoint into weights of vision encoder, projector, and LLM
|
82 |
+
vit_key_prefix = "vision_encoder."
|
83 |
+
vit_ckpt = {}
|
84 |
+
for key, value in ckpt.items():
|
85 |
+
if key.startswith(vit_key_prefix):
|
86 |
+
vit_ckpt[key[len(vit_key_prefix) :]] = value  # slice off the prefix (str.lstrip strips a character set, not a prefix)
|
87 |
+
|
88 |
+
projector_key_prefix = "vision_language_adapter."
|
89 |
+
projector_ckpt = {}
|
90 |
+
substring_replacement_map = {
|
91 |
+
"w_in.": "projector.0.",
|
92 |
+
"w_out.": "projector.2.",
|
93 |
+
}
|
94 |
+
for key, value in ckpt.items():
|
95 |
+
if key.startswith(projector_key_prefix):
|
96 |
+
key = key[len(projector_key_prefix) :]  # slice off the prefix
|
97 |
+
for old, new in substring_replacement_map.items():
|
98 |
+
key = key.replace(old, new)
|
99 |
+
projector_ckpt[key] = value
|
100 |
+
|
101 |
+
llm_ckpt = {}
|
102 |
+
for key, value in ckpt.items():
|
103 |
+
if key.startswith(vit_key_prefix) or key.startswith(projector_key_prefix):
|
104 |
+
continue
|
105 |
+
llm_ckpt[key] = value
|
106 |
+
|
107 |
+
vlm_ckpt = {}
|
108 |
+
for key, value in llm_ckpt.items():
|
109 |
+
vlm_ckpt["model." + key] = value
|
110 |
+
for key, value in projector_ckpt.items():
|
111 |
+
vlm_ckpt["mm_projector." + key] = value
|
112 |
+
for key, value in vit_ckpt.items():
|
113 |
+
vlm_ckpt["vision_encoder." + key] = value
|
114 |
+
|
115 |
+
# Load config
|
116 |
+
config_path = os.path.join(pixtral_ckpt_dir, "params.json")
|
117 |
+
with open(config_path, "r") as f:
|
118 |
+
pixtral_config = json.load(f)
|
119 |
+
|
120 |
+
# Extract the vision encoder configuration
|
121 |
+
vision_encoder_config = {
|
122 |
+
"dim": pixtral_config["vision_encoder"]["hidden_size"],
|
123 |
+
"num_channels": pixtral_config["vision_encoder"]["num_channels"],
|
124 |
+
"image_size": pixtral_config["vision_encoder"]["image_size"],
|
125 |
+
"patch_size": pixtral_config["vision_encoder"]["patch_size"],
|
126 |
+
"rope_theta": pixtral_config["vision_encoder"]["rope_theta"],
|
127 |
+
"ffn_hidden_size": pixtral_config["vision_encoder"]["intermediate_size"],
|
128 |
+
"n_layers": pixtral_config["vision_encoder"]["num_hidden_layers"],
|
129 |
+
"n_heads": pixtral_config["vision_encoder"]["num_attention_heads"],
|
130 |
+
"n_kv_heads": pixtral_config["vision_encoder"]["num_attention_heads"],
|
131 |
+
"norm_type": "rmsnorm",
|
132 |
+
"norm_eps": pixtral_config["norm_eps"],
|
133 |
+
"image_token_id": pixtral_config["vision_encoder"]["image_token_id"],
|
134 |
+
}
|
135 |
+
# Configuration for the 400M ViT of Pixtral 12B VLM
|
136 |
+
vit_config = dict(
|
137 |
+
dim=1024,
|
138 |
+
num_channels=3,
|
139 |
+
image_size=1024,
|
140 |
+
patch_size=16,
|
141 |
+
rope_theta=10000,
|
142 |
+
ffn_hidden_size=4096,
|
143 |
+
n_layers=24,
|
144 |
+
n_heads=16,
|
145 |
+
n_kv_heads=16,
|
146 |
+
norm_type="rmsnorm",
|
147 |
+
norm_eps=1e-5,
|
148 |
+
image_token_id=10,
|
149 |
+
)
|
150 |
+
# Compare the two configurations
|
151 |
+
for key, value in vit_config.items():
|
152 |
+
assert vision_encoder_config[key] == value, f"Mismatch in {key}: {vision_encoder_config[key]} != {value}"
|
153 |
+
|
154 |
+
llm_config_keys = [
|
155 |
+
"dim",
|
156 |
+
"n_layers",
|
157 |
+
"head_dim",
|
158 |
+
"hidden_dim",
|
159 |
+
"n_heads",
|
160 |
+
"n_kv_heads",
|
161 |
+
"rope_theta",
|
162 |
+
"norm_eps",
|
163 |
+
"vocab_size",
|
164 |
+
]
|
165 |
+
assert set(list(pixtral_config.keys())) == set(llm_config_keys + ["vision_encoder"]), "Config keys mismatch"
|
166 |
+
replace_map = {
|
167 |
+
"hidden_dim": "ffn_hidden_size",
|
168 |
+
}
|
169 |
+
llm_config = {}
|
170 |
+
for k, v in pixtral_config.items():
|
171 |
+
if k in llm_config_keys:
|
172 |
+
llm_config[replace_map.get(k, k)] = v
|
173 |
+
elif k == "vision_encoder":
|
174 |
+
llm_config["vision_encoder"] = vit_type
|
175 |
+
else:
|
176 |
+
raise ValueError(f"Unknown key: {k}")
|
177 |
+
|
178 |
+
ckpt_to_save = {"model": vlm_ckpt, "mm_projector": projector_ckpt, "vision_encoder": vit_ckpt}
|
179 |
+
torch.save(ckpt_to_save, save_path)
|
180 |
+
print(f"Model saved to {save_path}")
|
181 |
+
|
182 |
+
# Save config
|
183 |
+
config_path = os.path.join(save_dir, "config.json")
|
184 |
+
with open(config_path, "w") as f:
|
185 |
+
json.dump(llm_config, f)
|
186 |
+
|
187 |
+
torch.set_default_dtype(orig_dtype) # Reset the default dtype
|
188 |
+
|
189 |
+
# Remove the original Pixtral checkpoint
|
190 |
+
shutil.rmtree(pixtral_ckpt_dir, ignore_errors=True)
|
191 |
+
print(f"Removed {pixtral_ckpt_dir}")
|
192 |
+
|
193 |
+
|
194 |
+
if __name__ == "__main__":
|
195 |
+
parser = argparse.ArgumentParser(
|
196 |
+
description="Convert pretrained Pixtral vision model weights to checkpoint and verify accuracy"
|
197 |
+
)
|
198 |
+
parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Path to the checkpoint directory")
|
199 |
+
parser.add_argument(
|
200 |
+
"--checkpoint_name",
|
201 |
+
type=str,
|
202 |
+
default="Pixtral-12B",
|
203 |
+
help="Name of the checkpoint",
|
204 |
+
)
|
205 |
+
parser.add_argument("--vit_type", default="pixtral-12b-vit", help="Type of ViT used in the Pixtral model")
|
206 |
+
args = parser.parse_args()
|
207 |
+
convert_pixtral_checkpoint(
|
208 |
+
checkpoint_dir=args.checkpoint_dir, checkpoint_name=args.checkpoint_name, vit_type=args.vit_type
|
209 |
+
)
|
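Besides the CLI entry point shown in the module docstring, the converter can also be invoked directly from Python with the same values as the argparse defaults:

```python
from convert_pixtral_ckpt import convert_pixtral_checkpoint  # adjust to the package path in your checkout

# Equivalent to the CLI defaults: downloads Pixtral-12B-2409, splits the weights,
# and writes checkpoints/Pixtral-12B/model.pt plus config.json.
convert_pixtral_checkpoint(
    checkpoint_dir="checkpoints",
    checkpoint_name="Pixtral-12B",
    vit_type="pixtral-12b-vit",
)
```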
cosmos1/models/POST_TRAINING.md
ADDED
@@ -0,0 +1,23 @@
1 |
+
# Cosmos Post-training
|
2 |
+
|
3 |
+
In the [Cosmos paper](https://research.nvidia.com/publication/2025-01_cosmos-world-foundation-model-platform-physical-ai), we discuss several post-training examples of Cosmos pre-trained World Foundation Models (WFMs) for various Physical AI tasks, including
|
4 |
+
|
5 |
+
- General Post-Training: Fine-tune the WFM to generate a target distribution of videos based on the custom dataset. The target distribution could include a specific camera spec or a specific domain such as a factory.
|
6 |
+
- Instruction Control: Post-trains models for robotic manipulation to predict videos based on textual instructions, enabling robots to visually simulate tasks like folding clothes or picking up objects.
|
7 |
+
- Action Control: Post-trains models for robotic manipulation to predict the next visual frame based on action vectors, simulating robotic tasks like object handling or movement planning.
|
8 |
+
- Camera Control: Adds camera pose conditioning to generate 3D-consistent video simulations from single images, enabling joystick-like navigation in virtual environments.
|
9 |
+
- Multi-View Generation: Post-trains models for autonomous vehicles to generate synchronized multi-view videos from text prompts, simulating driving scenarios with multiple camera perspectives.
|
10 |
+
- Multi-View Generation with Vehicle Trajectory Control: Extends multi-view generation by incorporating trajectory inputs, enabling precise simulation of driving environments for autonomous vehicles, adhering to specified paths.
|
11 |
+
|
12 |
+
Except for the instruction control case, where the WFM is post-trained on a dataset of instruction-video pairs, all other cases require minor modifications of the network architectures. Post-training tasks will be supported by the NeMo Framework. In this initial release, we provide post-training scripts for the general post-training of both diffusion and autoregressive WFMs. Scripts for the other post-training tasks will be provided in a future release.
|
13 |
+
|
14 |
+
## Post-training Support Matrix
|
15 |
+
|
16 |
+
| Post-training Task | Diffusion WFM | Autoregressive WFM |
|
17 |
+
|---------------------|---------------|--------------------|
|
18 |
+
| General post-training | [Supported](../models/diffusion/nemo/post_training/README.md) | [Supported](../models/autoregressive/nemo/post_training/README.md) |
|
19 |
+
| Instruction control | Coming soon | Coming soon |
|
20 |
+
| Action control | Coming soon | Coming soon |
|
21 |
+
| Camera control | Coming soon | Coming soon |
|
22 |
+
| Multi-view generation | Coming soon | Coming soon |
|
23 |
+
| Multi-view generation with vehicle trajectory control | Coming soon | Coming soon |
|
cosmos1/models/autoregressive/README.md
ADDED
@@ -0,0 +1,427 @@
1 |
+
# Cosmos Autoregressive-based World Foundation Models
|
2 |
+
|
3 |
+
## Table of Contents
|
4 |
+
- [Getting Started](#getting-started)
|
5 |
+
- [Set Up Docker Environment](#set-up-docker-environment)
|
6 |
+
- [Download Checkpoints](#download-checkpoints)
|
7 |
+
- [Usage](#usage)
|
8 |
+
- [Model Types](#model-types)
|
9 |
+
- [Single and Batch Generation](#single-and-batch-generation)
|
10 |
+
- [Sample Commands](#sample-commands)
|
11 |
+
- [Base Models (4B/12B)](#base-basepy-4b-and-12b)
|
12 |
+
- [Video2World Models (5B/13B)](#video2world-video2worldpy-5b-and-13b)
|
13 |
+
- [Arguments](#arguments)
|
14 |
+
- [Common Parameters](#common-parameters)
|
15 |
+
- [Base Specific Parameters](#base-specific-parameters)
|
16 |
+
- [Video2World Specific Parameters](#video2world-specific-parameters)
|
17 |
+
- [Safety Features](#safety-features)
|
18 |
+
|
19 |
+
This page details the steps for using the Cosmos autoregressive-based world foundation models.
|
20 |
+
|
21 |
+
## Getting Started
|
22 |
+
|
23 |
+
### Set Up Docker Environment
|
24 |
+
|
25 |
+
Follow our [Installation Guide](../../../INSTALL.md) to set up the Docker environment. All commands on this page should be run inside Docker.
|
26 |
+
|
27 |
+
### Download Checkpoints
|
28 |
+
|
29 |
+
1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token. Set the access token to 'Read' permission (default is 'Fine-grained').
|
30 |
+
|
31 |
+
2. Log in to Hugging Face with the access token:
|
32 |
+
|
33 |
+
```bash
|
34 |
+
huggingface-cli login
|
35 |
+
```
|
36 |
+
|
37 |
+
3. Download the Cosmos model weights from [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-6751e884dc10e013a0a0d8e6):
|
38 |
+
|
39 |
+
```bash
|
40 |
+
PYTHONPATH=$(pwd) python cosmos1/scripts/download_autoregressive.py --model_sizes 4B 5B 12B 13B
|
41 |
+
```
|
42 |
+
|
43 |
+
4. The downloaded files should be in the following structure:
|
44 |
+
|
45 |
+
```
|
46 |
+
checkpoints/
|
47 |
+
├── Cosmos-1.0-Autoregressive-4B
|
48 |
+
│ ├── model.pt
|
49 |
+
│ └── config.json
|
50 |
+
├── Cosmos-1.0-Autoregressive-5B-Video2World
|
51 |
+
│ ├── model.pt
|
52 |
+
│ └── config.json
|
53 |
+
├── Cosmos-1.0-Autoregressive-12B
|
54 |
+
│ ├── model.pt
|
55 |
+
│ └── config.json
|
56 |
+
├── Cosmos-1.0-Autoregressive-13B-Video2World
|
57 |
+
│ ├── model.pt
|
58 |
+
│ └── config.json
|
59 |
+
├── Cosmos-1.0-Tokenizer-CV8x8x8
|
60 |
+
│ ├── decoder.jit
|
61 |
+
│ ├── encoder.jit
|
62 |
+
│ └── mean_std.pt
|
63 |
+
├── Cosmos-1.0-Tokenizer-DV8x16x16
|
64 |
+
│ ├── decoder.jit
|
65 |
+
│ └── encoder.jit
|
66 |
+
├── Cosmos-1.0-Diffusion-7B-Decoder-DV8x16x16ToCV8x8x8
|
67 |
+
│ ├── aux_vars.pt
|
68 |
+
│ └── model.pt
|
69 |
+
└── Cosmos-1.0-Guardrail
|
70 |
+
├── aegis/
|
71 |
+
├── blocklist/
|
72 |
+
├── face_blur_filter/
|
73 |
+
└── video_content_safety_filter/
|
74 |
+
```
|
75 |
+
|
76 |
+
## Usage
|
77 |
+
|
78 |
+
|
79 |
+
### Model Types
|
80 |
+
|
81 |
+
There are two model types available for autoregressive world generation:
|
82 |
+
|
83 |
+
1. **Base**: Supports world generation from image/video input
|
84 |
+
|
85 |
+
* Models: `Cosmos-1.0-Autoregressive-4B` and `Cosmos-1.0-Autoregressive-12B`
|
86 |
+
* Inference script: [base.py](/cosmos1/models/autoregressive/inference/base.py)
|
87 |
+
|
88 |
+
2. **Video2World**: Supports world generation from image/video input and text input
|
89 |
+
|
90 |
+
* Models: `Cosmos-1.0-Autoregressive-5B-Video2World` and `Cosmos-1.0-Autoregressive-13B-Video2World`
|
91 |
+
* Inference script: [video2world.py](/cosmos1/models/autoregressive/inference/video2world.py)
|
92 |
+
|
93 |
+
Our models now support video extension up to 33 frames. Starting from either a single image or a 9-frame video input, they can generate the remaining frames to reach the 33-frame length (generating 32 or 24 frames, respectively).
|
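As a quick illustration of the frame budget described above (a sketch for clarity, not part of the inference scripts), the number of generated frames is simply the 33-frame target minus the number of input frames:

```python
# Sketch of the frame budget described above; the 33-frame target is taken from the text.
TOTAL_FRAMES = 33  # fixed output length

for num_input_frames in (1, 9):  # single image vs. 9-frame video input
    num_generated = TOTAL_FRAMES - num_input_frames
    print(f"{num_input_frames} input frame(s) -> {num_generated} generated frames")
# 1 input frame(s) -> 32 generated frames
# 9 input frame(s) -> 24 generated frames
```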
94 |
+
|
95 |
+
We have evaluated all eight possible configurations (4 models × 2 vision input types: image or video) using 100 test videos on physical AI topics. Below are the failure rates for each configuration:
|
96 |
+
|
97 |
+
| Model | Image input | Video input (9 frames) |
|
98 |
+
|:------------------------------------------|:--------------:|:-------------------------:|
|
99 |
+
| Cosmos-1.0-Autoregressive-4B | 15% | 1% |
|
100 |
+
| Cosmos-1.0-Autoregressive-5B-Video2World | 7% | 2% |
|
101 |
+
| Cosmos-1.0-Autoregressive-12B | 2% | 1% |
|
102 |
+
| Cosmos-1.0-Autoregressive-13B-Video2World | 3% | 0% |
|
103 |
+
|
104 |
+
We define failure cases as videos with severe distortions, such as:
|
105 |
+
|
106 |
+
* Sudden appearance of large unexpected objects
|
107 |
+
* Video degrading to a single solid color
|
108 |
+
|
109 |
+
Note that the following are not considered failures in our analysis:
|
110 |
+
|
111 |
+
* Static video frames
|
112 |
+
* Minor object distortions or artifacts
|
113 |
+
|
114 |
+
### Single and Batch Generation
|
115 |
+
|
116 |
+
We support both single and batch video generation.
|
117 |
+
|
118 |
+
For generating a single video, `base` mode requires the input argument `--input_image_or_video_path` (image/video input), while `video2world` mode requires both `--input_image_or_video_path` (image/video input) and `--prompt` (text input).
|
119 |
+
|
120 |
+
Note that our models only work with videos at 1024x640 resolution. If the input image/video is not at this resolution, it will be resized and cropped.
|
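If you want to preview roughly what that resize-and-crop does to your input, the sketch below scales the image to cover 1024x640 and then center-crops it. This is an illustrative assumption only; the inference scripts apply their own preprocessing, and `my_input.jpg` is a hypothetical path.

```python
# Illustrative resize-and-center-crop to 1024x640; not the scripts' actual preprocessing.
from PIL import Image

TARGET_W, TARGET_H = 1024, 640

def resize_and_center_crop(image: Image.Image) -> Image.Image:
    # Scale so the image covers the target size, then crop the overflow symmetrically.
    scale = max(TARGET_W / image.width, TARGET_H / image.height)
    resized = image.resize((round(image.width * scale), round(image.height * scale)))
    left = (resized.width - TARGET_W) // 2
    top = (resized.height - TARGET_H) // 2
    return resized.crop((left, top, left + TARGET_W, top + TARGET_H))

# "my_input.jpg" is a hypothetical path used only for illustration.
resize_and_center_crop(Image.open("my_input.jpg")).save("my_input_1024x640.jpg")
```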
121 |
+
|
122 |
+
For generating a batch of videos, both `base` and `video2world` require `--batch_input_path` (path to a JSONL file). For `base`, each line of the JSONL file describes one visual input and must contain a "visual_input" field:
|
123 |
+
|
124 |
+
```json
|
125 |
+
{"visual_input": "path/to/video1.mp4"}
|
126 |
+
{"visual_input": "path/to/video2.mp4"}
|
127 |
+
```
|
128 |
+
|
129 |
+
For `video2world`, each line in the JSONL file must contain both "prompt" and "visual_input" fields:
|
130 |
+
|
131 |
+
```json
|
132 |
+
{"prompt": "prompt1", "visual_input": "path/to/video1.mp4"}
|
133 |
+
{"prompt": "prompt2", "visual_input": "path/to/video2.mp4"}
|
134 |
+
```
|
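Because each line is an independent JSON object, these batch files are easy to generate programmatically. Below is a minimal sketch (file names, prompts, and paths are hypothetical) that writes a `video2world`-style JSONL; for `base`, keep only the "visual_input" field:

```python
# Minimal sketch for writing a batch-input JSONL file; paths and prompts are hypothetical.
import json

examples = [
    {"prompt": "A busy intersection at dusk.", "visual_input": "inputs/clip_0.mp4"},
    {"prompt": "A highway drive in light rain.", "visual_input": "inputs/clip_1.mp4"},
]

with open("video2world.jsonl", "w") as f:
    for example in examples:
        f.write(json.dumps(example) + "\n")  # one JSON object per line
```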
135 |
+
|
136 |
+
### Sample Commands
|
137 |
+
|
138 |
+
There are two main demo scripts for autoregressive world generation: `base.py` and `video2world.py`. Below you will find sample commands for single and batch generation, as well as commands for running with low-memory GPUs using model offloading. We also provide a memory usage table comparing different offloading strategies to help with configuration.
|
139 |
+
|
140 |
+
#### Base (base.py): 4B and 12B
|
141 |
+
|
142 |
+
Generates world from image/video input.
|
143 |
+
|
144 |
+
The `input_type` argument can be either `video` or `image`. We have tuned the sampling parameters `top_p` and `temperature` to achieve the best performance. Please use the provided values in the command examples.
|
145 |
+
|
146 |
+
Note that the command examples below all use video input. If you want to use image input, please change the `input_type` to `image`.
|
147 |
+
|
148 |
+
##### Single Generation
|
149 |
+
|
150 |
+
```bash
|
151 |
+
# Example using 4B model
|
152 |
+
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \
|
153 |
+
--input_type=video \
|
154 |
+
--input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
|
155 |
+
--video_save_name=Cosmos-1.0-Autoregressive-4B \
|
156 |
+
--ar_model_dir=Cosmos-1.0-Autoregressive-4B \
|
157 |
+
--top_p=0.8 \
|
158 |
+
--temperature=1.0
|
159 |
+
|
160 |
+
# Example for low-memory GPUs using 4B model with model offloading
|
161 |
+
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \
|
162 |
+
--input_type=video \
|
163 |
+
--input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
|
164 |
+
--video_save_name=Cosmos-1.0-Autoregressive-4B \
|
165 |
+
--ar_model_dir=Cosmos-1.0-Autoregressive-4B \
|
166 |
+
--top_p=0.8 \
|
167 |
+
--temperature=1.0 \
|
168 |
+
--offload_guardrail_models \
|
169 |
+
--offload_diffusion_decoder \
|
170 |
+
--offload_ar_model \
|
171 |
+
--offload_tokenizer
|
172 |
+
|
173 |
+
# Example using 12B model
|
174 |
+
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \
|
175 |
+
--input_type=video \
|
176 |
+
--input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
|
177 |
+
--video_save_name=Cosmos-1.0-Autoregressive-12B \
|
178 |
+
--ar_model_dir=Cosmos-1.0-Autoregressive-12B \
|
179 |
+
--top_p=0.9 \
|
180 |
+
--temperature=1.0
|
181 |
+
|
182 |
+
# Example for low-memory GPUs using 12B model with model offloading
|
183 |
+
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \
|
184 |
+
--input_type=video \
|
185 |
+
--input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
|
186 |
+
--video_save_name=Cosmos-1.0-Autoregressive-12B \
|
187 |
+
--ar_model_dir=Cosmos-1.0-Autoregressive-12B \
|
188 |
+
--top_p=0.9 \
|
189 |
+
--temperature=1.0 \
|
190 |
+
--offload_guardrail_models \
|
191 |
+
--offload_diffusion_decoder \
|
192 |
+
--offload_ar_model \
|
193 |
+
--offload_tokenizer
|
194 |
+
```
|
195 |
+
|
196 |
+
##### Batch Generation
|
197 |
+
|
198 |
+
```bash
|
199 |
+
# Example using 4B model
|
200 |
+
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \
|
201 |
+
--input_type=video \
|
202 |
+
--batch_input_path=cosmos1/models/autoregressive/assets/v1p0/batch_inputs/base.jsonl \
|
203 |
+
--video_save_folder=outputs/Cosmos-1.0-Autoregressive-4B \
|
204 |
+
--ar_model_dir=Cosmos-1.0-Autoregressive-4B \
|
205 |
+
--top_p=0.8 \
|
206 |
+
--temperature=1.0
|
207 |
+
|
208 |
+
# Example using 12B model
|
209 |
+
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \
|
210 |
+
--input_type=video \
|
211 |
+
--batch_input_path=cosmos1/models/autoregressive/assets/v1p0/batch_inputs/base.jsonl \
|
212 |
+
--video_save_folder=outputs/Cosmos-1.0-Autoregressive-12B \
|
213 |
+
--ar_model_dir=Cosmos-1.0-Autoregressive-12B \
|
214 |
+
--top_p=0.9 \
|
215 |
+
--temperature=1.0
|
216 |
+
```
|
217 |
+
|
218 |
+
##### Example Output
|
219 |
+
|
220 |
+
Here is an example output video generated by base.py from image input with `Cosmos-1.0-Autoregressive-12B`:
|
221 |
+
|
222 |
+
<video src="https://github.com/user-attachments/assets/634403a5-1873-42d7-8dd0-eb7fb4ac8cf4">
|
223 |
+
Your browser does not support the video tag.
|
224 |
+
</video>
|
225 |
+
|
226 |
+
The input image used to generate this video can be found in `cosmos1/models/autoregressive/assets/v1p0/input.jpg`. The image is from the [BDD dataset](http://bdd-data.berkeley.edu/).
|
227 |
+
|
228 |
+
Here is an example output video generated by base.py from 9-frame video input with `Cosmos-1.0-Autoregressive-12B`:
|
229 |
+
|
230 |
+
<video src="https://github.com/user-attachments/assets/1a3ff099-87d7-41e8-b149-a25cfcd4f40b">
|
231 |
+
Your browser does not support the video tag.
|
232 |
+
</video>
|
233 |
+
|
234 |
+
The input video used to generate this video can be found in `cosmos1/models/autoregressive/assets/v1p0/input.mp4`.
|
235 |
+
|
236 |
+
##### Inference Time and GPU Memory Usage
|
237 |
+
|
238 |
+
These numbers may vary based on system specifications and are provided for reference only.
|
239 |
+
|
240 |
+
| Offloading Strategy | Cosmos-1.0-Autoregressive-4B | Cosmos-1.0-Autoregressive-12B |
|
241 |
+
|-------------|---------|---------|
|
242 |
+
| No offloading | 31.3 GB | 47.5 GB |
|
243 |
+
| Guardrails | 28.9 GB | 45.2 GB |
|
244 |
+
| Guardrails & Diffusion decoder | 28.5 GB | 43.1 GB |
|
245 |
+
| Guardrails & Diffusion decoder & Tokenizer | 27.3 GB | 42.9 GB |
|
246 |
+
| Guardrails & Diffusion decoder & Tokenizer & AR model | 18.7 GB | 27.4 GB |
|
247 |
+
|
248 |
+
End-to-end inference runtime on one H100 without offloading and after model initialization:
|
249 |
+
|
250 |
+
| Cosmos-1.0-Autoregressive-4B | Cosmos-1.0-Autoregressive-12B |
|
251 |
+
|---------|---------|
|
252 |
+
| ~62 seconds | ~119 seconds |
|
253 |
+
|
254 |
+
#### Video2World (video2world.py): 5B and 13B
|
255 |
+
|
256 |
+
Generates world from image/video and text input.
|
257 |
+
|
258 |
+
The `input_type` argument can be either `text_and_video` or `text_and_image`. We have tuned the sampling parameters `top_p` and `temperature` to achieve the best performance. Please use the provided values in the command examples.
|
259 |
+
|
260 |
+
Note that the command examples below all use video input. If you want to use image input, please change the `input_type` to `text_and_image`.
|
261 |
+
|
262 |
+
##### Single Generation
|
263 |
+
|
264 |
+
```bash
|
265 |
+
# Example using 5B model
|
266 |
+
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \
|
267 |
+
--input_type=text_and_video \
|
268 |
+
--input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
|
269 |
+
--prompt="A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions." \
|
270 |
+
--video_save_name=Cosmos-1.0-Autoregressive-5B-Video2World \
|
271 |
+
--ar_model_dir=Cosmos-1.0-Autoregressive-5B-Video2World \
|
272 |
+
--top_p=0.7 \
|
273 |
+
--temperature=1.0
|
274 |
+
|
275 |
+
# Example for low-memory GPUs using 5B model with model offloading
|
276 |
+
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \
|
277 |
+
--input_type=text_and_video \
|
278 |
+
--input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
|
279 |
+
--prompt="A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions." \
|
280 |
+
--video_save_name=Cosmos-1.0-Autoregressive-5B-Video2World \
|
281 |
+
--ar_model_dir=Cosmos-1.0-Autoregressive-5B-Video2World \
|
282 |
+
--top_p=0.7 \
|
283 |
+
--temperature=1.0 \
|
284 |
+
--offload_guardrail_models \
|
285 |
+
--offload_diffusion_decoder \
|
286 |
+
--offload_ar_model \
|
287 |
+
--offload_tokenizer \
|
288 |
+
--offload_text_encoder_model
|
289 |
+
|
290 |
+
# Example using 13B model
|
291 |
+
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \
|
292 |
+
--input_type=text_and_video \
|
293 |
+
--input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
|
294 |
+
--prompt="A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions." \
|
295 |
+
--video_save_name=Cosmos-1.0-Autoregressive-13B-Video2World \
|
296 |
+
--ar_model_dir=Cosmos-1.0-Autoregressive-13B-Video2World \
|
297 |
+
--top_p=0.8 \
|
298 |
+
--temperature=1.0 \
|
299 |
+
--offload_guardrail_models
|
300 |
+
|
301 |
+
# Example for low-memory GPUs using 13B model with model offloading
|
302 |
+
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \
|
303 |
+
--input_type=text_and_video \
|
304 |
+
--input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
|
305 |
+
--prompt="A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions." \
|
306 |
+
--video_save_name=Cosmos-1.0-Autoregressive-13B-Video2World \
|
307 |
+
--ar_model_dir=Cosmos-1.0-Autoregressive-13B-Video2World \
|
308 |
+
--top_p=0.8 \
|
309 |
+
--temperature=1.0 \
|
310 |
+
--offload_guardrail_models \
|
311 |
+
--offload_diffusion_decoder \
|
312 |
+
--offload_ar_model \
|
313 |
+
--offload_tokenizer \
|
314 |
+
--offload_text_encoder_model
|
315 |
+
```
|
316 |
+
|
317 |
+
##### Batch Generation
|
318 |
+
|
319 |
+
```bash
|
320 |
+
# Example using 5B model
|
321 |
+
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \
|
322 |
+
--input_type=text_and_video \
|
323 |
+
--batch_input_path=cosmos1/models/autoregressive/assets/v1p0/batch_inputs/video2world.jsonl \
|
324 |
+
--video_save_folder=outputs/Cosmos-1.0-Autoregressive-5B-Video2World \
|
325 |
+
--ar_model_dir=Cosmos-1.0-Autoregressive-5B-Video2World \
|
326 |
+
--top_p=0.7 \
|
327 |
+
--temperature=1.0
|
328 |
+
|
329 |
+
# Example using 13B model
|
330 |
+
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \
|
331 |
+
--input_type=text_and_video \
|
332 |
+
--batch_input_path=cosmos1/models/autoregressive/assets/v1p0/batch_inputs/video2world.jsonl \
|
333 |
+
--video_save_folder=outputs/Cosmos-1.0-Autoregressive-13B-Video2World \
|
334 |
+
--ar_model_dir=Cosmos-1.0-Autoregressive-13B-Video2World \
|
335 |
+
--top_p=0.8 \
|
336 |
+
--temperature=1.0 \
|
337 |
+
--offload_guardrail_models
|
338 |
+
```
|
339 |
+
|
340 |
+
##### Example Output
|
341 |
+
|
342 |
+
Here is an example output video generated by video2world.py from image input with `Cosmos-1.0-Autoregressive-13B-Video2World`:
|
343 |
+
|
344 |
+
<video src="https://github.com/user-attachments/assets/869f3b81-fabd-462e-a545-c04cdd9c1d22">
|
345 |
+
Your browser does not support the video tag.
|
346 |
+
</video>
|
347 |
+
|
348 |
+
The input image used to generate this video can be found in `cosmos1/models/autoregressive/assets/v1p0/input.jpg`. The prompt for generating the video is:
|
349 |
+
|
350 |
+
```
|
351 |
+
A driving video captures a serene urban street scene on a sunny day. The camera is mounted on the dashboard of a moving vehicle, providing a first-person perspective as it travels down a two-lane road. The street is lined with parked cars on both sides, predominantly black and silver sedans and SUVs. The road is flanked by a mix of residential and commercial buildings, with a prominent red-brick building on the left side, featuring multiple windows and a flat roof. The sky is clear with a few scattered clouds, casting soft shadows on the street. Trees with lush green foliage line the right side of the road, providing a natural contrast to the urban environment. The camera remains steady, maintaining a consistent forward motion, suggesting a leisurely drive. Traffic is light, with a few vehicles moving in the opposite direction, including a black sedan and a yellow taxi. Street signs are visible, including a no-parking sign on the right. The overall atmosphere is calm and peaceful, with no pedestrians visible, emphasizing the focus on the drive and the surrounding urban landscape.
|
352 |
+
```
|
353 |
+
|
354 |
+
Here is an example output video generated by video2world.py from 9-frame video input with `Cosmos-1.0-Autoregressive-13B-Video2World`:
|
355 |
+
|
356 |
+
<video src="https://github.com/user-attachments/assets/81840e1c-624b-4b01-9240-ab7db3722e58">
|
357 |
+
Your browser does not support the video tag.
|
358 |
+
</video>
|
359 |
+
|
360 |
+
The input video used to generate this video can be found in `cosmos1/models/autoregressive/assets/v1p0/input.mp4`. The prompt for generating the video is:
|
361 |
+
|
362 |
+
```
|
363 |
+
A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions.
|
364 |
+
```
|
365 |
+
|
366 |
+
##### Inference Time and GPU Memory Usage
|
367 |
+
|
368 |
+
These numbers may vary based on system specifications and are provided for reference only.
|
369 |
+
|
370 |
+
| Offloading Strategy | Cosmos-1.0-Autoregressive-5B-Video2World | Cosmos-1.0-Autoregressive-13B-Video2World |
|
371 |
+
|-------------|---------|---------|
|
372 |
+
| No offloading | 66.2 GB | > 80 GB |
|
373 |
+
| Guardrails | 58.7 GB | 76.6 GB |
|
374 |
+
| Guardrails & T5 encoder | 41.3 GB | 58.0 GB |
|
375 |
+
| Guardrails & T5 encoder & Diffusion decoder | 29.0 GB | 46.9 GB |
|
376 |
+
| Guardrails & T5 encoder & Diffusion decoder & Tokenizer | 28.8 GB | 46.7 GB |
|
377 |
+
| Guardrails & T5 encoder & Diffusion decoder & Tokenizer & AR model | 21.1 GB | 30.9 GB |
|
378 |
+
|
379 |
+
End-to-end inference runtime on one H100, after model initialization, with no offloading for the 5B model and guardrail offloading for the 13B model:
|
380 |
+
|
381 |
+
| Cosmos-1.0-Autoregressive-5B-Video2World | Cosmos-1.0-Autoregressive-13B-Video2World |
|
382 |
+
|---------|---------|
|
383 |
+
| ~73 seconds | ~150 seconds |
|
384 |
+
|
385 |
+
### Arguments
|
386 |
+
|
387 |
+
#### Common Parameters
|
388 |
+
|
389 |
+
| Parameter | Description | Default |
|
390 |
+
|-----------|-------------|---------|
|
391 |
+
| `--checkpoint_dir` | Directory containing model weights | "checkpoints" |
|
392 |
+
| `--video_save_name` | Output video filename for single video generation | "output" |
|
393 |
+
| `--video_save_folder` | Folder where all output videos are stored | "outputs/" |
|
394 |
+
| `--input_image_or_video_path` | Input image or video path. Required for single video generation | None |
|
395 |
+
| `--batch_input_path` | Path to a JSONL file listing the input images or videos. Required for batch video generation | None |
|
396 |
+
| `--num_input_frames` | Number of input frames to use for Video2World prediction | 9 |
|
397 |
+
| `--temperature` | Temperature used while sampling | 1.0 (use the values provided in the sample commands) |
|
398 |
+
| `--top_p` | Top-p value for top-p sampling | 0.8 (use the values provided in the sample commands) |
|
399 |
+
| `--seed` | Random seed | 0 |
|
400 |
+
| `--disable_diffusion_decoder` | When set, decode the discrete tokens to video with the discrete tokenizer instead of the diffusion decoder | False |
|
401 |
+
| `--offload_guardrail_models` | Offload guardrail models after inference, used for low-memory GPUs | False |
|
402 |
+
| `--offload_diffusion_decoder` | Offload diffusion decoder after inference, used for low-memory GPUs | False |
|
403 |
+
| `--offload_ar_model` | Offload AR model after inference, used for low-memory GPUs | False |
|
404 |
+
| `--offload_prompt_upsampler` | Offload prompt upsampler after inference, used for low-memory GPUs | False |
|
405 |
+
|
406 |
+
#### Base Specific Parameters
|
407 |
+
|
408 |
+
| Parameter | Description | Default |
|
409 |
+
|-----------|-------------|---------|
|
410 |
+
| `--ar_model_dir` | Directory containing the AR model weights | "Cosmos-1.0-Autoregressive-4B" |
|
411 |
+
| `--input_type` | Input type, either `video` or `image` | "video" |
|
412 |
+
|
413 |
+
#### Video2World Specific Parameters
|
414 |
+
|
415 |
+
| Parameter | Description | Default |
|
416 |
+
|-----------|-------------|---------|
|
417 |
+
| `--ar_model_dir` | Directory containing the AR model weights | "Cosmos-1.0-Autoregressive-4B" |
|
418 |
+
| `--input_type` | Input type, either `text_and_video` or `text_and_image` | "text_and_video" |
|
419 |
+
| `--prompt` | Text prompt for single video generation. Required for single video generation | None |
|
420 |
+
| `--input_prompts_path` | Path to JSONL file for batch video generation. Required for batch video generation | None |
|
421 |
+
| `--offload_text_encoder_model` | Offload text encoder after inference, used for low-memory GPUs | False |
|
422 |
+
|
423 |
+
### Safety Features
|
424 |
+
|
425 |
+
The models use a built-in safety guardrail system that cannot be disabled. Generating human faces is not allowed; any faces that do appear in the output are blurred by the guardrail.
|
426 |
+
|
427 |
+
For more information, check out the [Cosmos Guardrail Documentation](../guardrail/README.md).
|
cosmos1/models/autoregressive/__init__.py
ADDED
@@ -0,0 +1,14 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
cosmos1/models/autoregressive/assets/nemo/finetuned_result.mp4
ADDED
Binary file (193 kB).
|
|
cosmos1/models/autoregressive/assets/v1p0/batch_inputs/0.mp4
ADDED
Binary file (299 kB).
|
|
cosmos1/models/autoregressive/assets/v1p0/batch_inputs/1.mp4
ADDED
Binary file (222 kB).
|
|
cosmos1/models/autoregressive/assets/v1p0/batch_inputs/2.mp4
ADDED
Binary file (511 kB).
|
|
cosmos1/models/autoregressive/assets/v1p0/batch_inputs/3.mp4
ADDED
Binary file (461 kB).
|
|
cosmos1/models/autoregressive/assets/v1p0/batch_inputs/4.mp4
ADDED
Binary file (331 kB).
|
|
cosmos1/models/autoregressive/assets/v1p0/batch_inputs/5.mp4
ADDED
Binary file (282 kB).
|
|
cosmos1/models/autoregressive/assets/v1p0/batch_inputs/6.mp4
ADDED
Binary file (289 kB).
|
|
cosmos1/models/autoregressive/assets/v1p0/batch_inputs/7.mp4
ADDED
Binary file (170 kB).
|
|
cosmos1/models/autoregressive/assets/v1p0/batch_inputs/8.mp4
ADDED
Binary file (188 kB).
|
|
cosmos1/models/autoregressive/assets/v1p0/batch_inputs/9.mp4
ADDED
Binary file (174 kB).
|
|