Spaces:
Running
Running
Merge branch 'main' of https://github.com/borisdayma/dalle-mini into add-custom-model
Browse files- .github/workflows/check_size.yml +17 -0
- .github/workflows/style.yml +20 -0
- .github/workflows/sync_to_hub.yml +20 -0
- .github/workflows/sync_to_hub_debug.yml +17 -0
- .gitignore +4 -0
- CITATION.cff +44 -0
- LICENSE +201 -0
- Makefile +5 -0
- README.md +144 -30
- app/gradio/app_gradio.py +179 -0
- app/gradio/requirements.txt +4 -0
- app/streamlit/app.py +117 -0
- app/streamlit/img/loading.gif +0 -0
- dalle_mini/data.py +261 -0
- dalle_mini/dataset.py +0 -122
- dalle_mini/model.py +64 -0
- dalle_mini/text.py +258 -0
- dalle_mini/vqgan_jax/__init__.py +0 -0
- dalle_mini/vqgan_jax/configuration_vqgan.py +0 -40
- dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +0 -109
- dalle_mini/vqgan_jax/modeling_flax_vqgan.py +0 -609
- data/CC12M_downloader.py +0 -91
- data/CC3M_downloader.py +0 -62
- demo/CustomBARTv4b_model-generate.ipynb +0 -566
- demo/demo_notebook.ipynb +0 -583
- encoding/vqgan-jax-encoding-with-captions.ipynb +0 -363
- encoding/vqgan-jax-encoding-yfcc100m.ipynb +0 -1136
- encoding/vqgan-jax-encoding.ipynb +0 -0
- environment.yaml +0 -10
- img/logo.png +0 -0
- model/data-pipeline.ipynb +0 -385
- pyproject.toml +2 -0
- requirements.txt +0 -9
- seq2seq/do_big_run.sh +0 -16
- seq2seq/do_small_run.sh +0 -16
- seq2seq/requirements.txt +0 -8
- seq2seq/run_seq2seq_flax.py +0 -897
- setup.cfg +27 -0
- setup.py +4 -0
- tools/dataset/encode_dataset.ipynb +371 -0
- tools/inference/inference_pipeline.ipynb +0 -0
- tools/inference/log_inference_samples.ipynb +434 -0
- tools/inference/samples.txt +124 -0
- {seq2seq → tools/train}/sweep.yaml +34 -23
- tools/train/train.py +857 -0
.github/workflows/check_size.yml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Check file size
|
2 |
+
|
3 |
+
on:
|
4 |
+
pull_request:
|
5 |
+
branches: [main]
|
6 |
+
|
7 |
+
# to run this workflow manually from the Actions tab
|
8 |
+
workflow_dispatch:
|
9 |
+
|
10 |
+
jobs:
|
11 |
+
sync-to-hub:
|
12 |
+
runs-on: ubuntu-latest
|
13 |
+
steps:
|
14 |
+
- name: Check large files
|
15 |
+
uses: ActionsDesk/lfs-warning@v2.0
|
16 |
+
with:
|
17 |
+
filesizelimit: 10485760 # = 10MB, so we can sync to HF spaces
|
.github/workflows/style.yml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Lint
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches: [main]
|
6 |
+
pull_request:
|
7 |
+
branches: [main]
|
8 |
+
|
9 |
+
jobs:
|
10 |
+
lint:
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
steps:
|
13 |
+
- uses: actions/checkout@v2
|
14 |
+
- uses: psf/black@stable
|
15 |
+
- uses: actions/setup-python@v2
|
16 |
+
with:
|
17 |
+
python-version: 3.9
|
18 |
+
- name: Install requirements
|
19 |
+
run: pip install ".[dev]"
|
20 |
+
- uses: jamescurtin/isort-action@master
|
.github/workflows/sync_to_hub.yml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Sync to Hugging Face hub
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches: [main]
|
6 |
+
|
7 |
+
# to run this workflow manually from the Actions tab
|
8 |
+
workflow_dispatch:
|
9 |
+
|
10 |
+
jobs:
|
11 |
+
sync-to-hub:
|
12 |
+
runs-on: ubuntu-latest
|
13 |
+
steps:
|
14 |
+
- uses: actions/checkout@v2
|
15 |
+
with:
|
16 |
+
fetch-depth: 0
|
17 |
+
- name: Push to hub
|
18 |
+
env:
|
19 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
20 |
+
run: git push https://boris:$HF_TOKEN@huggingface.co/spaces/flax-community/dalle-mini main
|
.github/workflows/sync_to_hub_debug.yml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Deploy to debug app
|
2 |
+
|
3 |
+
on:
|
4 |
+
# to run this workflow manually from the Actions tab
|
5 |
+
workflow_dispatch:
|
6 |
+
|
7 |
+
jobs:
|
8 |
+
sync-to-hub-debug:
|
9 |
+
runs-on: ubuntu-latest
|
10 |
+
steps:
|
11 |
+
- uses: actions/checkout@v2
|
12 |
+
with:
|
13 |
+
fetch-depth: 0
|
14 |
+
- name: Push to hub
|
15 |
+
env:
|
16 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
17 |
+
run: git push --force https://boris:$HF_TOKEN@huggingface.co/spaces/flax-community/dalle-mini-debug +HEAD:main
|
.gitignore
CHANGED
@@ -1 +1,5 @@
|
|
1 |
__pycache__
|
|
|
|
|
|
|
|
|
|
1 |
__pycache__
|
2 |
+
.ipynb_checkpoints
|
3 |
+
.streamlit
|
4 |
+
wandb/
|
5 |
+
*.egg-info/
|
CITATION.cff
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# YAML 1.2
|
2 |
+
---
|
3 |
+
abstract: "DALL·E mini is a JAX/Flax reimplementation of OpenAI's DALL·E that requires much smaller hardware resources. By simplifying the architecture and model memory requirements, as well as leveraging open-source code and pre-trained models, we were able to create a model that is 27 times smaller than the original DALL·E and train it on a single TPU v3-8 for only 3 days. DALL·E mini achieves impressive results, albeit of a lower quality than the original system. It can be used for exploration and further experimentation on commodity hardware."
|
4 |
+
authors:
|
5 |
+
-
|
6 |
+
family-names: Dayma
|
7 |
+
given-names: Boris
|
8 |
+
-
|
9 |
+
family-names: Patil
|
10 |
+
given-names: Suraj
|
11 |
+
-
|
12 |
+
family-names: Cuenca
|
13 |
+
given-names: Pedro
|
14 |
+
-
|
15 |
+
family-names: Saifullah
|
16 |
+
given-names: Khalid
|
17 |
+
-
|
18 |
+
family-names: Abraham
|
19 |
+
given-names: Tanishq
|
20 |
+
-
|
21 |
+
family-names: "Lê Khắc"
|
22 |
+
given-names: "Phúc"
|
23 |
+
-
|
24 |
+
family-names: Melas
|
25 |
+
given-names: Luke
|
26 |
+
-
|
27 |
+
family-names: Ghosh
|
28 |
+
given-names: Ritobrata
|
29 |
+
cff-version: "1.1.0"
|
30 |
+
date-released: 2021-07-29
|
31 |
+
identifiers:
|
32 |
+
keywords:
|
33 |
+
- dalle
|
34 |
+
- "text-to-image generation"
|
35 |
+
- transformer
|
36 |
+
- "zero-shot"
|
37 |
+
- JAX
|
38 |
+
license: "Apache-2.0"
|
39 |
+
doi: 10.5281/zenodo.5146400
|
40 |
+
message: "If you use this project, please cite it using these metadata."
|
41 |
+
repository-code: "https://github.com/borisdayma/dalle-mini"
|
42 |
+
title: "DALL·E Mini"
|
43 |
+
version: "v0.1-alpha"
|
44 |
+
...
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2021 The DALL·E mini Authors
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
Makefile
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.PHONY: style
|
2 |
+
|
3 |
+
style:
|
4 |
+
black .
|
5 |
+
isort .
|
README.md
CHANGED
@@ -1,42 +1,156 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
|
5 |
-
|
6 |
-
* [Conceptual 12M](https://github.com/google-research-datasets/conceptual-12m) Dataset (already loaded and preprocessed in TPU VM by Luke).
|
7 |
-
* [YFCC100M Subset](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md)
|
8 |
-
* [Coneptual Captions 3M](https://github.com/google-research-datasets/conceptual-captions)
|
9 |
|
10 |
-
|
11 |
-
* Use the Taming Transformers VQ-GAN (with 16384 tokens)
|
12 |
-
* Use a seq2seq (language encoder --> image decoder) model with a pretrained non-autoregressive encoder (e.g. BERT) and an autoregressive decoder (like GPT).
|
13 |
|
14 |
-
|
15 |
-
* Whether to freeze the text encoder?
|
16 |
-
* Whether to finetune the VQ-GAN?
|
17 |
-
* Which text encoder to use (e.g. BERT, RoBERTa, etc.)?
|
18 |
-
* Hyperparameter choices for the decoder (e.g. positional embedding, initialization, etc.)
|
19 |
|
20 |
-
|
21 |
|
22 |
-
|
23 |
-
* work on dataset loading - [see suggested datasets](https://discuss.huggingface.co/t/dall-e-mini-version/7324/4)
|
24 |
-
* Optionally create the OpenAI YFCC100M subset (see [this post](https://discuss.huggingface.co/t/dall-e-mini-version/7324/30?u=boris))
|
25 |
-
* work on text/image encoding
|
26 |
-
* concatenate inputs (not sure if we need fixed length for text or use a special token separating text & image)
|
27 |
-
* adapt training script
|
28 |
-
* create inference function
|
29 |
-
* integrate CLIP for better results (only if we have the time)
|
30 |
-
* work on a demo (streamlit or colab or maybe just HF widget)
|
31 |
-
* document (set up repo on model hub per instructions, start on README writeup…)
|
32 |
-
* help with coordinating activities & progress
|
33 |
|
|
|
34 |
|
35 |
-
|
36 |
-
You should create a new python virtual environment and install the project dependencies inside the virtual env. You need to use the `-f` (`--find-links`) option for `pip` to be able to find the appropriate `libtpu` required for the TPU hardware:
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
```
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
```
|
41 |
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: DALL·E mini
|
3 |
+
emoji: 🥑
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: green
|
6 |
+
sdk: streamlit
|
7 |
+
app_file: app/streamlit/app.py
|
8 |
+
pinned: True
|
9 |
+
---
|
10 |
|
11 |
+
# DALL·E Mini
|
12 |
|
13 |
+
[![Join us on Discord](https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white)](https://discord.gg/xBPBXfcFHd)
|
|
|
|
|
|
|
14 |
|
15 |
+
_Generate images from a text prompt_
|
|
|
|
|
16 |
|
17 |
+
<img src="img/logo.png" width="200">
|
|
|
|
|
|
|
|
|
18 |
|
19 |
+
Our logo was generated with DALL·E mini using the prompt "logo of an armchair in the shape of an avocado".
|
20 |
|
21 |
+
You can create your own pictures with [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
## How does it work?
|
24 |
|
25 |
+
Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA).
|
|
|
26 |
|
27 |
+
## Development
|
28 |
+
|
29 |
+
### Dependencies Installation
|
30 |
+
|
31 |
+
For inference only, use `pip install git+https://github.com/borisdayma/dalle-mini.git`.
|
32 |
+
|
33 |
+
For development, clone the repo and use `pip install -e ".[dev]"`.
|
34 |
+
|
35 |
+
### Training of VQGAN
|
36 |
+
|
37 |
+
The VQGAN was trained using [taming-transformers](https://github.com/CompVis/taming-transformers).
|
38 |
+
|
39 |
+
We recommend using the latest version available.
|
40 |
+
|
41 |
+
### Conversion of VQGAN to JAX
|
42 |
+
|
43 |
+
Use [patil-suraj/vqgan-jax](https://github.com/patil-suraj/vqgan-jax).
|
44 |
+
|
45 |
+
### Training of Seq2Seq
|
46 |
+
|
47 |
+
Use [`tools/train/train.py`](tools/train/train.py).
|
48 |
+
|
49 |
+
You can also adjust the [sweep configuration file](https://docs.wandb.ai/guides/sweeps) if you need to perform a hyperparameter search.
|
50 |
+
|
51 |
+
### Inference Pipeline
|
52 |
+
|
53 |
+
To generate sample predictions and understand the inference pipeline step by step, refer to [`tools/inference/inference_pipeline.ipynb`](tools/inference/inference_pipeline.ipynb).
|
54 |
+
|
55 |
+
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb)
|
56 |
+
|
57 |
+
## FAQ
|
58 |
+
|
59 |
+
### Where to find the latest models?
|
60 |
+
|
61 |
+
Trained models are on 🤗 Model Hub:
|
62 |
+
|
63 |
+
- [VQGAN-f16-16384](https://huggingface.co/flax-community/vqgan_f16_16384) for encoding/decoding images
|
64 |
+
- [DALL·E mini](https://huggingface.co/flax-community/dalle-mini) for generating images from a text prompt
|
65 |
+
|
66 |
+
### Where does the logo come from?
|
67 |
+
|
68 |
+
The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL·E to illustrate the model's capabilities. Having successful predictions on this prompt represents a big milestone to us.
|
69 |
+
|
70 |
+
## Authors & Contributors
|
71 |
+
|
72 |
+
### Main Authors
|
73 |
+
|
74 |
+
- [Boris Dayma](https://github.com/borisdayma)
|
75 |
+
- [Suraj Patil](https://github.com/patil-suraj)
|
76 |
+
- [Pedro Cuenca](https://github.com/pcuenca)
|
77 |
+
|
78 |
+
### Other Members of dalle-mini team during FLAX/JAX community week
|
79 |
+
|
80 |
+
- [Khalid Saifullah](https://github.com/khalidsaifullaah)
|
81 |
+
- [Tanishq Abraham](https://github.com/tmabraham)
|
82 |
+
- [Phúc Lê Khắc](https://github.com/lkhphuc)
|
83 |
+
- [Luke Melas](https://github.com/lukemelas)
|
84 |
+
- [Ritobrata Ghosh](https://github.com/ghosh-r)
|
85 |
+
|
86 |
+
### Contributing
|
87 |
+
|
88 |
+
Join the community on the [DALLE-Pytorch Discord](https://discord.gg/xBPBXfcFHd).
|
89 |
+
Any contribution is welcome, from reporting issues to proposing fixes/improvements or testing the model on cool prompts!
|
90 |
+
|
91 |
+
## Acknowledgements
|
92 |
+
|
93 |
+
- 🤗 Hugging Face for organizing [the FLAX/JAX community week](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects)
|
94 |
+
- Google [TPU Research Cloud (TRC) program](https://sites.research.google/trc/) for providing computing resources
|
95 |
+
- [Weights & Biases](https://wandb.com/) for providing the infrastructure for experiment tracking and model management
|
96 |
+
|
97 |
+
## Citing DALL·E mini
|
98 |
+
|
99 |
+
If you find DALL·E mini useful in your research or wish to refer, please use the following BibTeX entry.
|
100 |
+
|
101 |
+
```
|
102 |
+
@misc{Dayma_DALL·E_Mini_2021,
|
103 |
+
author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
|
104 |
+
doi = {10.5281/zenodo.5146400},
|
105 |
+
month = {7},
|
106 |
+
title = {DALL·E Mini},
|
107 |
+
url = {https://github.com/borisdayma/dalle-mini},
|
108 |
+
year = {2021}
|
109 |
+
}
|
110 |
```
|
111 |
+
|
112 |
+
## References
|
113 |
+
|
114 |
+
```
|
115 |
+
@misc{ramesh2021zeroshot,
|
116 |
+
title={Zero-Shot Text-to-Image Generation},
|
117 |
+
author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
|
118 |
+
year={2021},
|
119 |
+
eprint={2102.12092},
|
120 |
+
archivePrefix={arXiv},
|
121 |
+
primaryClass={cs.CV}
|
122 |
+
}
|
123 |
+
```
|
124 |
+
|
125 |
+
```
|
126 |
+
@misc{esser2021taming,
|
127 |
+
title={Taming Transformers for High-Resolution Image Synthesis},
|
128 |
+
author={Patrick Esser and Robin Rombach and Björn Ommer},
|
129 |
+
year={2021},
|
130 |
+
eprint={2012.09841},
|
131 |
+
archivePrefix={arXiv},
|
132 |
+
primaryClass={cs.CV}
|
133 |
+
}
|
134 |
```
|
135 |
|
136 |
+
```
|
137 |
+
@misc{lewis2019bart,
|
138 |
+
title={BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension},
|
139 |
+
author={Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and Abdelrahman Mohamed and Omer Levy and Ves Stoyanov and Luke Zettlemoyer},
|
140 |
+
year={2019},
|
141 |
+
eprint={1910.13461},
|
142 |
+
archivePrefix={arXiv},
|
143 |
+
primaryClass={cs.CL}
|
144 |
+
}
|
145 |
+
```
|
146 |
+
|
147 |
+
```
|
148 |
+
@misc{radford2021learning,
|
149 |
+
title={Learning Transferable Visual Models From Natural Language Supervision},
|
150 |
+
author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
|
151 |
+
year={2021},
|
152 |
+
eprint={2103.00020},
|
153 |
+
archivePrefix={arXiv},
|
154 |
+
primaryClass={cs.CV}
|
155 |
+
}
|
156 |
+
```
|
app/gradio/app_gradio.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
# Uncomment to run on cpu
|
5 |
+
# import os
|
6 |
+
# os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
7 |
+
|
8 |
+
import random
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
import jax
|
12 |
+
import numpy as np
|
13 |
+
from flax.jax_utils import replicate
|
14 |
+
from flax.training.common_utils import shard
|
15 |
+
from PIL import Image, ImageDraw, ImageFont
|
16 |
+
|
17 |
+
# ## CLIP Scoring
|
18 |
+
from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel
|
19 |
+
from vqgan_jax.modeling_flax_vqgan import VQModel
|
20 |
+
|
21 |
+
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
22 |
+
|
23 |
+
DALLE_REPO = "flax-community/dalle-mini"
|
24 |
+
DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114"
|
25 |
+
|
26 |
+
VQGAN_REPO = "flax-community/vqgan_f16_16384"
|
27 |
+
VQGAN_COMMIT_ID = "90cc46addd2dd8f5be21586a9a23e1b95aa506a9"
|
28 |
+
|
29 |
+
tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
|
30 |
+
model = CustomFlaxBartForConditionalGeneration.from_pretrained(
|
31 |
+
DALLE_REPO, revision=DALLE_COMMIT_ID
|
32 |
+
)
|
33 |
+
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
|
34 |
+
|
35 |
+
|
36 |
+
def captioned_strip(images, caption=None, rows=1):
|
37 |
+
increased_h = 0 if caption is None else 48
|
38 |
+
w, h = images[0].size[0], images[0].size[1]
|
39 |
+
img = Image.new("RGB", (len(images) * w // rows, h * rows + increased_h))
|
40 |
+
for i, img_ in enumerate(images):
|
41 |
+
img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))
|
42 |
+
|
43 |
+
if caption is not None:
|
44 |
+
draw = ImageDraw.Draw(img)
|
45 |
+
font = ImageFont.truetype(
|
46 |
+
"/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
|
47 |
+
)
|
48 |
+
draw.text((20, 3), caption, (255, 255, 255), font=font)
|
49 |
+
return img
|
50 |
+
|
51 |
+
|
52 |
+
def custom_to_pil(x):
|
53 |
+
x = np.clip(x, 0.0, 1.0)
|
54 |
+
x = (255 * x).astype(np.uint8)
|
55 |
+
x = Image.fromarray(x)
|
56 |
+
if not x.mode == "RGB":
|
57 |
+
x = x.convert("RGB")
|
58 |
+
return x
|
59 |
+
|
60 |
+
|
61 |
+
def generate(input, rng, params):
|
62 |
+
return model.generate(
|
63 |
+
**input,
|
64 |
+
max_length=257,
|
65 |
+
num_beams=1,
|
66 |
+
do_sample=True,
|
67 |
+
prng_key=rng,
|
68 |
+
eos_token_id=50000,
|
69 |
+
pad_token_id=50000,
|
70 |
+
params=params,
|
71 |
+
)
|
72 |
+
|
73 |
+
|
74 |
+
def get_images(indices, params):
|
75 |
+
return vqgan.decode_code(indices, params=params)
|
76 |
+
|
77 |
+
|
78 |
+
p_generate = jax.pmap(generate, "batch")
|
79 |
+
p_get_images = jax.pmap(get_images, "batch")
|
80 |
+
|
81 |
+
bart_params = replicate(model.params)
|
82 |
+
vqgan_params = replicate(vqgan.params)
|
83 |
+
|
84 |
+
clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
85 |
+
print("Initialize FlaxCLIPModel")
|
86 |
+
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
87 |
+
print("Initialize CLIPProcessor")
|
88 |
+
|
89 |
+
|
90 |
+
def hallucinate(prompt, num_images=64):
|
91 |
+
prompt = [prompt] * jax.device_count()
|
92 |
+
inputs = tokenizer(
|
93 |
+
prompt,
|
94 |
+
return_tensors="jax",
|
95 |
+
padding="max_length",
|
96 |
+
truncation=True,
|
97 |
+
max_length=128,
|
98 |
+
).data
|
99 |
+
inputs = shard(inputs)
|
100 |
+
|
101 |
+
all_images = []
|
102 |
+
for i in range(num_images // jax.device_count()):
|
103 |
+
key = random.randint(0, 1e7)
|
104 |
+
rng = jax.random.PRNGKey(key)
|
105 |
+
rngs = jax.random.split(rng, jax.local_device_count())
|
106 |
+
indices = p_generate(inputs, rngs, bart_params).sequences
|
107 |
+
indices = indices[:, :, 1:]
|
108 |
+
|
109 |
+
images = p_get_images(indices, vqgan_params)
|
110 |
+
images = np.squeeze(np.asarray(images), 1)
|
111 |
+
for image in images:
|
112 |
+
all_images.append(custom_to_pil(image))
|
113 |
+
return all_images
|
114 |
+
|
115 |
+
|
116 |
+
def clip_top_k(prompt, images, k=8):
|
117 |
+
inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
|
118 |
+
outputs = clip(**inputs)
|
119 |
+
logits = outputs.logits_per_text
|
120 |
+
scores = np.array(logits[0]).argsort()[-k:][::-1]
|
121 |
+
return [images[score] for score in scores]
|
122 |
+
|
123 |
+
|
124 |
+
def compose_predictions(images, caption=None):
|
125 |
+
increased_h = 0 if caption is None else 48
|
126 |
+
w, h = images[0].size[0], images[0].size[1]
|
127 |
+
img = Image.new("RGB", (len(images) * w, h + increased_h))
|
128 |
+
for i, img_ in enumerate(images):
|
129 |
+
img.paste(img_, (i * w, increased_h))
|
130 |
+
|
131 |
+
if caption is not None:
|
132 |
+
draw = ImageDraw.Draw(img)
|
133 |
+
font = ImageFont.truetype(
|
134 |
+
"/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
|
135 |
+
)
|
136 |
+
draw.text((20, 3), caption, (255, 255, 255), font=font)
|
137 |
+
return img
|
138 |
+
|
139 |
+
|
140 |
+
def top_k_predictions(prompt, num_candidates=32, k=8):
|
141 |
+
images = hallucinate(prompt, num_images=num_candidates)
|
142 |
+
images = clip_top_k(prompt, images, k=k)
|
143 |
+
return images
|
144 |
+
|
145 |
+
|
146 |
+
def run_inference(prompt, num_images=32, num_preds=8):
|
147 |
+
images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
|
148 |
+
predictions = captioned_strip(images)
|
149 |
+
output_title = f"""
|
150 |
+
<b>{prompt}</b>
|
151 |
+
"""
|
152 |
+
return (output_title, predictions)
|
153 |
+
|
154 |
+
|
155 |
+
outputs = [
|
156 |
+
gr.outputs.HTML(label=""), # To be used as title
|
157 |
+
gr.outputs.Image(label=""),
|
158 |
+
]
|
159 |
+
|
160 |
+
description = """
|
161 |
+
DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text:
|
162 |
+
"""
|
163 |
+
gr.Interface(
|
164 |
+
run_inference,
|
165 |
+
inputs=[gr.inputs.Textbox(label="What do you want to see?")],
|
166 |
+
outputs=outputs,
|
167 |
+
title="DALL·E mini",
|
168 |
+
description=description,
|
169 |
+
article="<p style='text-align: center'> Created by Boris Dayma et al. 2021 | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a> | <a href='https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA'>Report</a></p>",
|
170 |
+
layout="vertical",
|
171 |
+
theme="huggingface",
|
172 |
+
examples=[
|
173 |
+
["an armchair in the shape of an avocado"],
|
174 |
+
["snowy mountains by the sea"],
|
175 |
+
],
|
176 |
+
allow_flagging=False,
|
177 |
+
live=False,
|
178 |
+
# server_port=8999
|
179 |
+
).launch(share=True)
|
app/gradio/requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Requirements for huggingface spaces
|
2 |
+
gradio>=2.2.3
|
3 |
+
flax
|
4 |
+
transformers
|
app/streamlit/app.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
import base64
|
5 |
+
from io import BytesIO
|
6 |
+
|
7 |
+
import requests
|
8 |
+
import streamlit as st
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
|
12 |
+
class ServiceError(Exception):
|
13 |
+
def __init__(self, status_code):
|
14 |
+
self.status_code = status_code
|
15 |
+
|
16 |
+
|
17 |
+
def get_images_from_backend(prompt, backend_url):
|
18 |
+
r = requests.post(backend_url, json={"prompt": prompt})
|
19 |
+
if r.status_code == 200:
|
20 |
+
images = r.json()["images"]
|
21 |
+
images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
|
22 |
+
return images
|
23 |
+
else:
|
24 |
+
raise ServiceError(r.status_code)
|
25 |
+
|
26 |
+
|
27 |
+
st.sidebar.markdown(
|
28 |
+
"""
|
29 |
+
<style>
|
30 |
+
.aligncenter {
|
31 |
+
text-align: center;
|
32 |
+
}
|
33 |
+
</style>
|
34 |
+
<p class="aligncenter">
|
35 |
+
<img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/img/logo.png"/>
|
36 |
+
</p>
|
37 |
+
""",
|
38 |
+
unsafe_allow_html=True,
|
39 |
+
)
|
40 |
+
st.sidebar.markdown(
|
41 |
+
"""
|
42 |
+
___
|
43 |
+
<p style='text-align: center'>
|
44 |
+
DALL·E mini is an AI model that generates images from any prompt you give!
|
45 |
+
</p>
|
46 |
+
|
47 |
+
<p style='text-align: center'>
|
48 |
+
Created by Boris Dayma et al. 2021
|
49 |
+
<br/>
|
50 |
+
<a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
|
51 |
+
</p>
|
52 |
+
""",
|
53 |
+
unsafe_allow_html=True,
|
54 |
+
)
|
55 |
+
|
56 |
+
st.header("DALL·E mini")
|
57 |
+
st.subheader("Generate images from text")
|
58 |
+
|
59 |
+
prompt = st.text_input("What do you want to see?")
|
60 |
+
|
61 |
+
DEBUG = False
|
62 |
+
if prompt != "":
|
63 |
+
container = st.empty()
|
64 |
+
container.markdown(
|
65 |
+
f"""
|
66 |
+
<style> p {{ margin:0 }} div {{ margin:0 }} </style>
|
67 |
+
<div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
|
68 |
+
<div class="stAlert">
|
69 |
+
<div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
|
70 |
+
<div class="st-b7">
|
71 |
+
<div class="css-whx05o e13vu3m50">
|
72 |
+
<div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
|
73 |
+
<img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
|
74 |
+
Generating predictions for: <b>{prompt}</b>
|
75 |
+
</div>
|
76 |
+
</div>
|
77 |
+
</div>
|
78 |
+
</div>
|
79 |
+
</div>
|
80 |
+
</div>
|
81 |
+
<small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
|
82 |
+
""",
|
83 |
+
unsafe_allow_html=True,
|
84 |
+
)
|
85 |
+
|
86 |
+
try:
|
87 |
+
backend_url = st.secrets["BACKEND_SERVER"]
|
88 |
+
print(f"Getting selections: {prompt}")
|
89 |
+
selected = get_images_from_backend(prompt, backend_url)
|
90 |
+
|
91 |
+
margin = 0.1 # for better position of zoom in arrow
|
92 |
+
n_columns = 3
|
93 |
+
cols = st.columns([1] + [margin, 1] * (n_columns - 1))
|
94 |
+
for i, img in enumerate(selected):
|
95 |
+
cols[(i % n_columns) * 2].image(img)
|
96 |
+
container.markdown(f"**{prompt}**")
|
97 |
+
|
98 |
+
st.button("Again!", key="again_button")
|
99 |
+
|
100 |
+
except ServiceError as error:
|
101 |
+
container.text(f"Service unavailable, status: {error.status_code}")
|
102 |
+
except KeyError:
|
103 |
+
if DEBUG:
|
104 |
+
container.markdown(
|
105 |
+
"""
|
106 |
+
**Error: BACKEND_SERVER unset**
|
107 |
+
|
108 |
+
Please, create a file called `.streamlit/secrets.toml` inside the app's folder and include a line to configure the server URL:
|
109 |
+
```
|
110 |
+
BACKEND_SERVER="<server url>"
|
111 |
+
```
|
112 |
+
"""
|
113 |
+
)
|
114 |
+
else:
|
115 |
+
container.markdown(
|
116 |
+
"Error -5, please try again or [report it](mailto:pcuenca-dalle@guenever.net)."
|
117 |
+
)
|
app/streamlit/img/loading.gif
ADDED
dalle_mini/data.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from functools import partial
|
3 |
+
|
4 |
+
import jax
|
5 |
+
import jax.numpy as jnp
|
6 |
+
import numpy as np
|
7 |
+
from datasets import Dataset, load_dataset
|
8 |
+
from flax.training.common_utils import shard
|
9 |
+
|
10 |
+
from .text import TextNormalizer
|
11 |
+
|
12 |
+
|
13 |
+
@dataclass
|
14 |
+
class Dataset:
|
15 |
+
dataset_repo_or_path: str
|
16 |
+
train_file: str = None
|
17 |
+
validation_file: str = None
|
18 |
+
dataset_type: str = "dataset"
|
19 |
+
streaming: bool = True
|
20 |
+
use_auth_token: bool = False
|
21 |
+
text_column: str = "caption"
|
22 |
+
encoding_column: str = "encoding"
|
23 |
+
max_source_length: int = 128
|
24 |
+
max_train_samples: int = None
|
25 |
+
max_eval_samples: int = None
|
26 |
+
preprocessing_num_workers: int = None
|
27 |
+
overwrite_cache: bool = False
|
28 |
+
do_train: bool = False
|
29 |
+
do_eval: bool = True
|
30 |
+
seed_dataset: int = None
|
31 |
+
train_dataset: Dataset = field(init=False)
|
32 |
+
eval_dataset: Dataset = field(init=False)
|
33 |
+
rng_dataset: jnp.ndarray = field(init=False)
|
34 |
+
|
35 |
+
def __post_init__(self):
|
36 |
+
# define data_files
|
37 |
+
if self.train_file is not None or self.validation_file is not None:
|
38 |
+
data_files = {
|
39 |
+
"train": self.train_file,
|
40 |
+
"validation": self.validation_file,
|
41 |
+
}
|
42 |
+
else:
|
43 |
+
data_files = None
|
44 |
+
|
45 |
+
# load dataset
|
46 |
+
dataset = load_dataset(
|
47 |
+
self.dataset_repo_or_path,
|
48 |
+
data_files=data_files,
|
49 |
+
streaming=self.streaming,
|
50 |
+
use_auth_token=self.use_auth_token,
|
51 |
+
)
|
52 |
+
if self.do_train:
|
53 |
+
if "train" not in dataset:
|
54 |
+
raise ValueError("Training requires a training dataset")
|
55 |
+
self.train_dataset = dataset["train"]
|
56 |
+
if self.max_train_samples is not None:
|
57 |
+
self.train_dataset = (
|
58 |
+
self.train_dataset.take(self.max_train_samples)
|
59 |
+
if self.streaming
|
60 |
+
else self.train_dataset.select(range(self.max_train_samples))
|
61 |
+
)
|
62 |
+
if self.do_eval:
|
63 |
+
if "validation" not in dataset:
|
64 |
+
raise ValueError("Evaluating requires a validation dataset")
|
65 |
+
self.eval_dataset = dataset["validation"]
|
66 |
+
if self.max_eval_samples is not None:
|
67 |
+
self.eval_dataset = (
|
68 |
+
self.eval_dataset.take(self.max_eval_samples)
|
69 |
+
if self.streaming
|
70 |
+
else self.eval_dataset.select(range(self.max_eval_samples))
|
71 |
+
)
|
72 |
+
|
73 |
+
def preprocess(self, tokenizer, decoder_start_token_id, normalize_text):
|
74 |
+
if self.streaming:
|
75 |
+
# we need to shuffle early in streaming mode
|
76 |
+
if hasattr(self, "train_dataset"):
|
77 |
+
self.train_dataset = self.train_dataset.shuffle(1000, self.seed_dataset)
|
78 |
+
else:
|
79 |
+
# prepare rng for later shuffling
|
80 |
+
if self.seed_dataset is None:
|
81 |
+
self.seed_dataset = np.random.get_state()[1][0]
|
82 |
+
self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
|
83 |
+
|
84 |
+
# normalize text
|
85 |
+
if normalize_text:
|
86 |
+
text_normalizer = TextNormalizer()
|
87 |
+
partial_normalize_function = partial(
|
88 |
+
normalize_function,
|
89 |
+
text_column=self.text_column,
|
90 |
+
text_normalizer=text_normalizer,
|
91 |
+
)
|
92 |
+
for ds in ["train_dataset", "eval_dataset"]:
|
93 |
+
if hasattr(self, ds):
|
94 |
+
setattr(
|
95 |
+
self,
|
96 |
+
ds,
|
97 |
+
(
|
98 |
+
getattr(self, ds).map(partial_normalize_function)
|
99 |
+
if self.streaming
|
100 |
+
else getattr(self, ds).map(
|
101 |
+
partial_normalize_function,
|
102 |
+
num_proc=self.preprocessing_num_workers,
|
103 |
+
load_from_cache_file=not self.overwrite_cache,
|
104 |
+
desc="Normalizing datasets",
|
105 |
+
)
|
106 |
+
),
|
107 |
+
)
|
108 |
+
|
109 |
+
# preprocess
|
110 |
+
partial_preprocess_function = partial(
|
111 |
+
preprocess_function,
|
112 |
+
tokenizer=tokenizer,
|
113 |
+
text_column=self.text_column,
|
114 |
+
encoding_column=self.encoding_column,
|
115 |
+
max_source_length=self.max_source_length,
|
116 |
+
decoder_start_token_id=decoder_start_token_id,
|
117 |
+
)
|
118 |
+
for ds in ["train_dataset", "eval_dataset"]:
|
119 |
+
if hasattr(self, ds):
|
120 |
+
setattr(
|
121 |
+
self,
|
122 |
+
ds,
|
123 |
+
(
|
124 |
+
getattr(self, ds).map(
|
125 |
+
partial_preprocess_function,
|
126 |
+
batched=True,
|
127 |
+
)
|
128 |
+
if self.streaming
|
129 |
+
else getattr(self, ds).map(
|
130 |
+
partial_preprocess_function,
|
131 |
+
batched=True,
|
132 |
+
remove_columns=getattr(ds, "column_names"),
|
133 |
+
num_proc=self.preprocessing_num_workers,
|
134 |
+
load_from_cache_file=not self.overwrite_cache,
|
135 |
+
desc="Preprocessing datasets",
|
136 |
+
)
|
137 |
+
),
|
138 |
+
)
|
139 |
+
|
140 |
+
def dataloader(self, split, batch_size, epoch=None):
|
141 |
+
def _dataloader_datasets_non_streaming(
|
142 |
+
dataset: Dataset,
|
143 |
+
batch_size: int,
|
144 |
+
rng: jax.random.PRNGKey = None,
|
145 |
+
):
|
146 |
+
"""
|
147 |
+
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
148 |
+
Shuffle batches if `shuffle` is `True`.
|
149 |
+
"""
|
150 |
+
steps_per_epoch = len(dataset) // batch_size
|
151 |
+
|
152 |
+
if rng is not None:
|
153 |
+
batch_idx = jax.random.permutation(rng, len(dataset))
|
154 |
+
else:
|
155 |
+
batch_idx = jnp.arange(len(dataset))
|
156 |
+
|
157 |
+
batch_idx = batch_idx[
|
158 |
+
: steps_per_epoch * batch_size
|
159 |
+
] # Skip incomplete batch.
|
160 |
+
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
161 |
+
|
162 |
+
for idx in batch_idx:
|
163 |
+
batch = dataset[idx]
|
164 |
+
batch = {k: jnp.array(v) for k, v in batch.items()}
|
165 |
+
batch = shard(batch)
|
166 |
+
yield batch
|
167 |
+
|
168 |
+
def _dataloader_datasets_streaming(dataset: Dataset, batch_size: int):
|
169 |
+
keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
|
170 |
+
batch = {k: [] for k in keys}
|
171 |
+
for item in dataset:
|
172 |
+
for k, v in item.items():
|
173 |
+
batch[k].append(v)
|
174 |
+
if len(batch[keys[0]]) == batch_size:
|
175 |
+
batch = {k: jnp.array(v) for k, v in batch.items()}
|
176 |
+
batch = shard(batch)
|
177 |
+
yield batch
|
178 |
+
batch = {k: [] for k in keys}
|
179 |
+
|
180 |
+
if split == "train":
|
181 |
+
ds = self.train_dataset
|
182 |
+
elif split == "eval":
|
183 |
+
ds = self.eval_dataset
|
184 |
+
else:
|
185 |
+
raise ValueError(f'split must be "train" or "eval", got {split}')
|
186 |
+
|
187 |
+
if self.streaming:
|
188 |
+
if split == "train":
|
189 |
+
ds.set_epoch(epoch)
|
190 |
+
return _dataloader_datasets_streaming(ds, batch_size)
|
191 |
+
else:
|
192 |
+
if split == "train":
|
193 |
+
self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
|
194 |
+
return _dataloader_datasets_non_streaming(ds, batch_size, input_rng)
|
195 |
+
|
196 |
+
@property
|
197 |
+
def length(self):
|
198 |
+
len_train_dataset, len_eval_dataset = None, None
|
199 |
+
if self.streaming:
|
200 |
+
# we don't know the length, let's just assume max_samples if defined
|
201 |
+
if self.max_train_samples is not None:
|
202 |
+
len_train_dataset = self.max_train_samples
|
203 |
+
if self.max_eval_samples is not None:
|
204 |
+
len_eval_dataset = self.max_eval_samples
|
205 |
+
else:
|
206 |
+
len_train_dataset = (
|
207 |
+
len(self.train_dataset) if hasattr(self, "train_dataset") else None
|
208 |
+
)
|
209 |
+
len_eval_dataset = (
|
210 |
+
len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
|
211 |
+
)
|
212 |
+
return len_train_dataset, len_eval_dataset
|
213 |
+
|
214 |
+
|
215 |
+
def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
|
216 |
+
"""
|
217 |
+
Shift input ids one token to the right.
|
218 |
+
"""
|
219 |
+
shifted_input_ids = np.zeros(input_ids.shape)
|
220 |
+
shifted_input_ids[:, 1:] = input_ids[:, :-1]
|
221 |
+
shifted_input_ids[:, 0] = decoder_start_token_id
|
222 |
+
return shifted_input_ids
|
223 |
+
|
224 |
+
|
225 |
+
def normalize_function(example, text_column, text_normalizer):
|
226 |
+
example[text_column] = text_normalizer(example[text_column])
|
227 |
+
return example
|
228 |
+
|
229 |
+
|
230 |
+
def preprocess_function(
|
231 |
+
examples,
|
232 |
+
tokenizer,
|
233 |
+
text_column,
|
234 |
+
encoding_column,
|
235 |
+
max_source_length,
|
236 |
+
decoder_start_token_id,
|
237 |
+
):
|
238 |
+
inputs = examples[text_column]
|
239 |
+
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
240 |
+
model_inputs = tokenizer(
|
241 |
+
inputs,
|
242 |
+
max_length=max_source_length,
|
243 |
+
padding="max_length",
|
244 |
+
truncation=True,
|
245 |
+
return_tensors="np",
|
246 |
+
)
|
247 |
+
|
248 |
+
# set up targets
|
249 |
+
# Note: labels correspond to our target indices
|
250 |
+
# decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
|
251 |
+
labels = examples[encoding_column]
|
252 |
+
labels = np.asarray(labels)
|
253 |
+
|
254 |
+
# We need the labels, in addition to the decoder_input_ids, for the compute_loss function
|
255 |
+
model_inputs["labels"] = labels
|
256 |
+
|
257 |
+
# In our case, this prepends the bos token and removes the last one
|
258 |
+
decoder_input_ids = shift_tokens_right(labels, decoder_start_token_id)
|
259 |
+
model_inputs["decoder_input_ids"] = decoder_input_ids
|
260 |
+
|
261 |
+
return model_inputs
|
dalle_mini/dataset.py
DELETED
@@ -1,122 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
An image-caption dataset dataloader.
|
3 |
-
Luke Melas-Kyriazi, 2021
|
4 |
-
"""
|
5 |
-
import warnings
|
6 |
-
from typing import Optional, Callable
|
7 |
-
from pathlib import Path
|
8 |
-
import numpy as np
|
9 |
-
import torch
|
10 |
-
import pandas as pd
|
11 |
-
from torch.utils.data import Dataset
|
12 |
-
from torchvision.datasets.folder import default_loader
|
13 |
-
from PIL import ImageFile
|
14 |
-
from PIL.Image import DecompressionBombWarning
|
15 |
-
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
16 |
-
warnings.filterwarnings("ignore", category=UserWarning)
|
17 |
-
warnings.filterwarnings("ignore", category=DecompressionBombWarning)
|
18 |
-
|
19 |
-
|
20 |
-
class CaptionDataset(Dataset):
|
21 |
-
"""
|
22 |
-
A PyTorch Dataset class for (image, texts) tasks. Note that this dataset
|
23 |
-
returns the raw text rather than tokens. This is done on purpose, because
|
24 |
-
it's easy to tokenize a batch of text after loading it from this dataset.
|
25 |
-
"""
|
26 |
-
|
27 |
-
def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None,
|
28 |
-
image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',
|
29 |
-
include_captions: bool = True):
|
30 |
-
"""
|
31 |
-
:param images_root: folder where images are stored
|
32 |
-
:param captions_path: path to csv that maps image filenames to captions
|
33 |
-
:param image_transform: image transform pipeline
|
34 |
-
:param text_transform: image transform pipeline
|
35 |
-
:param image_transform_type: image transform type, either `torchvision` or `albumentations`
|
36 |
-
:param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.
|
37 |
-
"""
|
38 |
-
|
39 |
-
# Base path for images
|
40 |
-
self.images_root = Path(images_root)
|
41 |
-
|
42 |
-
# Load captions as DataFrame
|
43 |
-
self.captions = pd.read_csv(captions_path, delimiter='\t', header=0)
|
44 |
-
self.captions['image_file'] = self.captions['image_file'].astype(str)
|
45 |
-
|
46 |
-
# PyTorch transformation pipeline for the image (normalizing, etc.)
|
47 |
-
self.text_transform = text_transform
|
48 |
-
self.image_transform = image_transform
|
49 |
-
self.image_transform_type = image_transform_type.lower()
|
50 |
-
assert self.image_transform_type in ['torchvision', 'albumentations']
|
51 |
-
|
52 |
-
# Total number of datapoints
|
53 |
-
self.size = len(self.captions)
|
54 |
-
|
55 |
-
# Return image+captions or just images
|
56 |
-
self.include_captions = include_captions
|
57 |
-
|
58 |
-
def verify_that_all_images_exist(self):
|
59 |
-
for image_file in self.captions['image_file']:
|
60 |
-
p = self.images_root / image_file
|
61 |
-
if not p.is_file():
|
62 |
-
print(f'file does not exist: {p}')
|
63 |
-
|
64 |
-
def _get_raw_image(self, i):
|
65 |
-
image_file = self.captions.iloc[i]['image_file']
|
66 |
-
image_path = self.images_root / image_file
|
67 |
-
image = default_loader(image_path)
|
68 |
-
return image
|
69 |
-
|
70 |
-
def _get_raw_text(self, i):
|
71 |
-
return self.captions.iloc[i]['caption']
|
72 |
-
|
73 |
-
def __getitem__(self, i):
|
74 |
-
image = self._get_raw_image(i)
|
75 |
-
caption = self._get_raw_text(i)
|
76 |
-
if self.image_transform is not None:
|
77 |
-
if self.image_transform_type == 'torchvision':
|
78 |
-
image = self.image_transform(image)
|
79 |
-
elif self.image_transform_type == 'albumentations':
|
80 |
-
image = self.image_transform(image=np.array(image))['image']
|
81 |
-
else:
|
82 |
-
raise NotImplementedError(f"{self.image_transform_type=}")
|
83 |
-
return {'image': image, 'text': caption} if self.include_captions else image
|
84 |
-
|
85 |
-
def __len__(self):
|
86 |
-
return self.size
|
87 |
-
|
88 |
-
|
89 |
-
if __name__ == "__main__":
|
90 |
-
import albumentations as A
|
91 |
-
from albumentations.pytorch import ToTensorV2
|
92 |
-
from transformers import AutoTokenizer
|
93 |
-
|
94 |
-
# Paths
|
95 |
-
images_root = './images'
|
96 |
-
captions_path = './images-list-clean.tsv'
|
97 |
-
|
98 |
-
# Create transforms
|
99 |
-
tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')
|
100 |
-
def tokenize(text):
|
101 |
-
return tokenizer(text, max_length=32, truncation=True, return_tensors='pt', padding='max_length')
|
102 |
-
image_transform = A.Compose([
|
103 |
-
A.Resize(256, 256), A.CenterCrop(256, 256),
|
104 |
-
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ToTensorV2()])
|
105 |
-
|
106 |
-
# Create dataset
|
107 |
-
dataset = CaptionDataset(
|
108 |
-
images_root=images_root,
|
109 |
-
captions_path=captions_path,
|
110 |
-
image_transform=image_transform,
|
111 |
-
text_transform=tokenize,
|
112 |
-
image_transform_type='albumentations')
|
113 |
-
|
114 |
-
# Create dataloader
|
115 |
-
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
|
116 |
-
batch = next(iter(dataloader))
|
117 |
-
print({k: (v.shape if isinstance(v, torch.Tensor) else v) for k, v in batch.items()})
|
118 |
-
|
119 |
-
# # (Optional) Check that all the images exist
|
120 |
-
# dataset = CaptionDataset(images_root=images_root, captions_path=captions_path)
|
121 |
-
# dataset.verify_that_all_images_exist()
|
122 |
-
# print('Done')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dalle_mini/model.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import flax.linen as nn
|
2 |
+
import jax
|
3 |
+
from transformers import BartConfig
|
4 |
+
from transformers.models.bart.modeling_flax_bart import (
|
5 |
+
FlaxBartDecoder,
|
6 |
+
FlaxBartEncoder,
|
7 |
+
FlaxBartForConditionalGeneration,
|
8 |
+
FlaxBartForConditionalGenerationModule,
|
9 |
+
FlaxBartModule,
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
class CustomFlaxBartModule(FlaxBartModule):
|
14 |
+
def setup(self):
|
15 |
+
# we keep shared to easily load pre-trained weights
|
16 |
+
self.shared = nn.Embed(
|
17 |
+
self.config.vocab_size,
|
18 |
+
self.config.d_model,
|
19 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
20 |
+
)
|
21 |
+
# a separate embedding is used for the decoder
|
22 |
+
self.decoder_embed = nn.Embed(
|
23 |
+
self.config.image_vocab_size + 1,
|
24 |
+
self.config.d_model,
|
25 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
26 |
+
)
|
27 |
+
self.encoder = FlaxBartEncoder(
|
28 |
+
self.config, dtype=self.dtype, embed_tokens=self.shared
|
29 |
+
)
|
30 |
+
|
31 |
+
# the decoder has a different config
|
32 |
+
# TODO: should not be needed once we have custom config/module
|
33 |
+
decoder_config = BartConfig(self.config.to_dict())
|
34 |
+
decoder_config.max_position_embeddings = (
|
35 |
+
self.config.image_length + 1 # image tokens + BOS
|
36 |
+
)
|
37 |
+
decoder_config.vocab_size = self.config.image_vocab_size + 1
|
38 |
+
self.decoder = FlaxBartDecoder(
|
39 |
+
decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed
|
40 |
+
)
|
41 |
+
|
42 |
+
|
43 |
+
class CustomFlaxBartForConditionalGenerationModule(
|
44 |
+
FlaxBartForConditionalGenerationModule
|
45 |
+
):
|
46 |
+
def setup(self):
|
47 |
+
# set default config
|
48 |
+
self.config.normalize_text = getattr(self.config, "normalize_text", False)
|
49 |
+
self.config.image_length = getattr(self.config, "image_length", 256)
|
50 |
+
self.config.image_vocab_size = getattr(self.config, "image_vocab_size", 16384)
|
51 |
+
|
52 |
+
self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
|
53 |
+
self.lm_head = nn.Dense(
|
54 |
+
self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
|
55 |
+
use_bias=False,
|
56 |
+
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
57 |
+
)
|
58 |
+
self.final_logits_bias = self.param(
|
59 |
+
"final_logits_bias", self.bias_init, (1, self.config.image_vocab_size + 1)
|
60 |
+
)
|
61 |
+
|
62 |
+
|
63 |
+
class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
|
64 |
+
module_class = CustomFlaxBartForConditionalGenerationModule
|
dalle_mini/text.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utilities for processing text.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import html
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
import re
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import ftfy
|
12 |
+
from huggingface_hub import hf_hub_download
|
13 |
+
from unidecode import unidecode
|
14 |
+
|
15 |
+
# based on wiki word occurence
|
16 |
+
person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
|
17 |
+
temp_token = "xtokx" # avoid repeating chars
|
18 |
+
|
19 |
+
|
20 |
+
class HashtagProcessor:
|
21 |
+
# Adapted from wordninja library
|
22 |
+
# We use our wikipedia word count + a good heuristic to make it work
|
23 |
+
def __init__(self):
|
24 |
+
wiki_word_frequency = hf_hub_download(
|
25 |
+
"dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt"
|
26 |
+
)
|
27 |
+
self._word_cost = (
|
28 |
+
l.split()[0] for l in Path(wiki_word_frequency).read_text().splitlines()
|
29 |
+
)
|
30 |
+
self._word_cost = {
|
31 |
+
str(k): math.log(float(i + 1)) for i, k in enumerate(self._word_cost)
|
32 |
+
}
|
33 |
+
self._max_word = max(len(x) for x in self._word_cost.keys())
|
34 |
+
self._SPLIT_RE = re.compile("[^a-zA-Z0-9']+")
|
35 |
+
|
36 |
+
def __call__(self, s):
|
37 |
+
"""Uses dynamic programming to infer the location of spaces in a string without spaces."""
|
38 |
+
l = [self._split(x) for x in self._SPLIT_RE.split(s)]
|
39 |
+
return " ".join([item for sublist in l for item in sublist])
|
40 |
+
|
41 |
+
def _split(self, s):
|
42 |
+
# Find the best match for the i first characters, assuming cost has
|
43 |
+
# been built for the i-1 first characters.
|
44 |
+
# Returns a pair (match_cost, match_length).
|
45 |
+
def best_match(i):
|
46 |
+
candidates = enumerate(reversed(cost[max(0, i - self._max_word) : i]))
|
47 |
+
return min(
|
48 |
+
(c + self._word_cost.get(s[i - k - 1 : i].lower(), 9e999), k + 1)
|
49 |
+
for k, c in candidates
|
50 |
+
)
|
51 |
+
|
52 |
+
# Build the cost array
|
53 |
+
cost = [0]
|
54 |
+
for i in range(1, len(s) + 1):
|
55 |
+
c, k = best_match(i)
|
56 |
+
cost.append(c)
|
57 |
+
|
58 |
+
# Backtrack to recover the minimal-cost string.
|
59 |
+
out = []
|
60 |
+
i = len(s)
|
61 |
+
while i > 0:
|
62 |
+
c, k = best_match(i)
|
63 |
+
assert c == cost[i]
|
64 |
+
newToken = True
|
65 |
+
if not s[i - k : i] == "'": # ignore a lone apostrophe
|
66 |
+
if len(out) > 0:
|
67 |
+
# re-attach split 's and split digits
|
68 |
+
if out[-1] == "'s" or (
|
69 |
+
s[i - 1].isdigit() and out[-1][0].isdigit()
|
70 |
+
): # digit followed by digit
|
71 |
+
out[-1] = (
|
72 |
+
s[i - k : i] + out[-1]
|
73 |
+
) # combine current token with previous token
|
74 |
+
newToken = False
|
75 |
+
|
76 |
+
if newToken:
|
77 |
+
out.append(s[i - k : i])
|
78 |
+
|
79 |
+
i -= k
|
80 |
+
|
81 |
+
return reversed(out)
|
82 |
+
|
83 |
+
|
84 |
+
def replace_person_token(t):
|
85 |
+
"Used for CC12M"
|
86 |
+
t = re.sub("<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
|
87 |
+
while "<person>" in t:
|
88 |
+
t = t.replace(
|
89 |
+
"<person>", f" {random.choices(*tuple(zip(*person_token)))[0]} ", 1
|
90 |
+
)
|
91 |
+
return t
|
92 |
+
|
93 |
+
|
94 |
+
def fix_html(t):
|
95 |
+
# from OpenAI CLIP
|
96 |
+
return html.unescape(html.unescape(t))
|
97 |
+
|
98 |
+
|
99 |
+
def replace_punctuation_with_commas(t):
|
100 |
+
return re.sub("[()[\].,|:;?!=+~\-\/{}]", ",", t)
|
101 |
+
|
102 |
+
|
103 |
+
def simplify_quotes(t):
|
104 |
+
return re.sub("""['"`]""", ' " ', t)
|
105 |
+
|
106 |
+
|
107 |
+
def merge_quotes(t):
|
108 |
+
return re.sub('(\s*"+\s*)+', ' " ', t)
|
109 |
+
|
110 |
+
|
111 |
+
def remove_comma_numbers(t):
|
112 |
+
def _f(t):
|
113 |
+
return re.sub("(\d),(\d{3})", r"\1\2", t)
|
114 |
+
|
115 |
+
return _f(_f(t))
|
116 |
+
|
117 |
+
|
118 |
+
def pre_process_dot_numbers(t):
|
119 |
+
return re.sub("(\w)\.(\w)", fr"\1{temp_token}dot{temp_token}\2", t)
|
120 |
+
|
121 |
+
|
122 |
+
def post_process_dot_numbers(t):
|
123 |
+
return re.sub(f"{temp_token}dot{temp_token}", ".", t)
|
124 |
+
|
125 |
+
|
126 |
+
def pre_process_quotes(t):
|
127 |
+
# allows quotes only for 's, 't, 'd, 'm, 'll, 're, 've
|
128 |
+
return re.sub(
|
129 |
+
r"'(?=([stdm]|(ll)|(re)|(ve)|(ll))\b)", fr"{temp_token}quote{temp_token}", t
|
130 |
+
)
|
131 |
+
|
132 |
+
|
133 |
+
def post_process_quotes(t):
|
134 |
+
return re.sub(f"{temp_token}quote{temp_token}", "'", t)
|
135 |
+
|
136 |
+
|
137 |
+
def pre_process_dates(t):
|
138 |
+
return re.sub("(\d)/(\d)", fr"\1{temp_token}slash{temp_token}\2", t)
|
139 |
+
|
140 |
+
|
141 |
+
def post_process_dates(t):
|
142 |
+
return re.sub(f"{temp_token}slash{temp_token}", "/", t)
|
143 |
+
|
144 |
+
|
145 |
+
def merge_commas(t):
|
146 |
+
return re.sub("(\s*,+\s*)+", ", ", t)
|
147 |
+
|
148 |
+
|
149 |
+
def add_space_after_commas(t):
|
150 |
+
return re.sub(",", ", ", t)
|
151 |
+
|
152 |
+
|
153 |
+
def handle_special_chars(t):
|
154 |
+
"Handle special characters"
|
155 |
+
# replace "-" with a space when between words without space
|
156 |
+
t = re.sub("(\w)-(\w)", r"\1 \2", t)
|
157 |
+
# always add space around some characters
|
158 |
+
return re.sub("([%&\/$*])", r" \1 ", t)
|
159 |
+
|
160 |
+
|
161 |
+
def expand_hashtags(t, hashtag_processor):
|
162 |
+
"Remove # and try to split words"
|
163 |
+
return re.sub("#(\w+)", lambda m: hashtag_processor(m.group(1)), t)
|
164 |
+
|
165 |
+
|
166 |
+
_re_ignore_chars = r"[_#\\]"
|
167 |
+
|
168 |
+
|
169 |
+
def ignore_chars(t):
|
170 |
+
"Ignore useless characters"
|
171 |
+
return re.sub(_re_ignore_chars, " ", t)
|
172 |
+
|
173 |
+
|
174 |
+
def remove_extra_spaces(t):
|
175 |
+
"Remove extra spaces (including \t and \n)"
|
176 |
+
return re.sub("\s+", " ", t)
|
177 |
+
|
178 |
+
|
179 |
+
def remove_repeating_chars(t):
|
180 |
+
"If the same character is present 4+ times (not 3 because of roman 'VIII'), replace with single instance"
|
181 |
+
return re.sub(r"(\D)(\1{3,})", r"\1", t)
|
182 |
+
|
183 |
+
|
184 |
+
def remove_urls(t):
|
185 |
+
return re.sub(r"http\S+", "", t)
|
186 |
+
|
187 |
+
|
188 |
+
def remove_html_tags(t):
|
189 |
+
return re.sub("<[^<]+?>", "", t)
|
190 |
+
|
191 |
+
|
192 |
+
def remove_first_last_commas(t):
|
193 |
+
t = t.strip()
|
194 |
+
t = t[:-1] if t and t[-1] == "," else t
|
195 |
+
t = t[1:] if t and t[0] == "," else t
|
196 |
+
return t.strip()
|
197 |
+
|
198 |
+
|
199 |
+
def remove_wiki_ref(t):
|
200 |
+
t = re.sub(r"\A\s*\[\d+\]", "", t)
|
201 |
+
return re.sub(r"\[\d+\]\s*\Z", "", t)
|
202 |
+
|
203 |
+
|
204 |
+
class TextNormalizer:
|
205 |
+
"Normalize text"
|
206 |
+
|
207 |
+
def __init__(self):
|
208 |
+
self._hashtag_processor = HashtagProcessor()
|
209 |
+
|
210 |
+
def __call__(self, t):
|
211 |
+
# fix some characters
|
212 |
+
t = ftfy.fix_text(t)
|
213 |
+
# fix html
|
214 |
+
t = fix_html(t)
|
215 |
+
# decode and simplify text: see unidecode library
|
216 |
+
t = unidecode(t)
|
217 |
+
# lower case
|
218 |
+
t = t.lower()
|
219 |
+
# replace <PERSON> (for CC12M)
|
220 |
+
t = replace_person_token(t)
|
221 |
+
# remove wiki reference (for WIT)
|
222 |
+
t = remove_wiki_ref(t)
|
223 |
+
# remove html tags
|
224 |
+
t = remove_html_tags(t)
|
225 |
+
# remove urls
|
226 |
+
t = remove_urls(t)
|
227 |
+
# remove commas in numbers
|
228 |
+
t = remove_comma_numbers(t)
|
229 |
+
# handle dots in numbers and quotes - Part 1
|
230 |
+
t = pre_process_dot_numbers(t)
|
231 |
+
t = pre_process_quotes(t)
|
232 |
+
t = pre_process_dates(t)
|
233 |
+
# handle special characters
|
234 |
+
t = handle_special_chars(t)
|
235 |
+
# handle hashtags
|
236 |
+
t = expand_hashtags(t, self._hashtag_processor)
|
237 |
+
# ignore useless characters
|
238 |
+
t = ignore_chars(t)
|
239 |
+
# simplify quotes
|
240 |
+
t = simplify_quotes(t)
|
241 |
+
# all punctuation becomes commas
|
242 |
+
t = replace_punctuation_with_commas(t)
|
243 |
+
# handle dots in numbers and quotes - Part 2
|
244 |
+
t = post_process_dot_numbers(t)
|
245 |
+
t = post_process_quotes(t)
|
246 |
+
t = post_process_dates(t)
|
247 |
+
# handle repeating characters
|
248 |
+
t = remove_repeating_chars(t)
|
249 |
+
# merge quotes
|
250 |
+
t = merge_quotes(t)
|
251 |
+
# merge commas
|
252 |
+
t = merge_commas(t)
|
253 |
+
# remove multiple spaces
|
254 |
+
t = remove_extra_spaces(t)
|
255 |
+
# remove first and last comma
|
256 |
+
t = remove_first_last_commas(t)
|
257 |
+
# always start with a space
|
258 |
+
return f" {t}"
|
dalle_mini/vqgan_jax/__init__.py
DELETED
File without changes
|
dalle_mini/vqgan_jax/configuration_vqgan.py
DELETED
@@ -1,40 +0,0 @@
|
|
1 |
-
from typing import Tuple
|
2 |
-
|
3 |
-
from transformers import PretrainedConfig
|
4 |
-
|
5 |
-
|
6 |
-
class VQGANConfig(PretrainedConfig):
|
7 |
-
def __init__(
|
8 |
-
self,
|
9 |
-
ch: int = 128,
|
10 |
-
out_ch: int = 3,
|
11 |
-
in_channels: int = 3,
|
12 |
-
num_res_blocks: int = 2,
|
13 |
-
resolution: int = 256,
|
14 |
-
z_channels: int = 256,
|
15 |
-
ch_mult: Tuple = (1, 1, 2, 2, 4),
|
16 |
-
attn_resolutions: int = (16,),
|
17 |
-
n_embed: int = 1024,
|
18 |
-
embed_dim: int = 256,
|
19 |
-
dropout: float = 0.0,
|
20 |
-
double_z: bool = False,
|
21 |
-
resamp_with_conv: bool = True,
|
22 |
-
give_pre_end: bool = False,
|
23 |
-
**kwargs,
|
24 |
-
):
|
25 |
-
super().__init__(**kwargs)
|
26 |
-
self.ch = ch
|
27 |
-
self.out_ch = out_ch
|
28 |
-
self.in_channels = in_channels
|
29 |
-
self.num_res_blocks = num_res_blocks
|
30 |
-
self.resolution = resolution
|
31 |
-
self.z_channels = z_channels
|
32 |
-
self.ch_mult = list(ch_mult)
|
33 |
-
self.attn_resolutions = list(attn_resolutions)
|
34 |
-
self.n_embed = n_embed
|
35 |
-
self.embed_dim = embed_dim
|
36 |
-
self.dropout = dropout
|
37 |
-
self.double_z = double_z
|
38 |
-
self.resamp_with_conv = resamp_with_conv
|
39 |
-
self.give_pre_end = give_pre_end
|
40 |
-
self.num_resolutions = len(ch_mult)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dalle_mini/vqgan_jax/convert_pt_model_to_jax.py
DELETED
@@ -1,109 +0,0 @@
|
|
1 |
-
import re
|
2 |
-
|
3 |
-
import jax.numpy as jnp
|
4 |
-
from flax.traverse_util import flatten_dict, unflatten_dict
|
5 |
-
|
6 |
-
import torch
|
7 |
-
|
8 |
-
from modeling_flax_vqgan import VQModel
|
9 |
-
from configuration_vqgan import VQGANConfig
|
10 |
-
|
11 |
-
|
12 |
-
regex = r"\w+[.]\d+"
|
13 |
-
|
14 |
-
|
15 |
-
def rename_key(key):
|
16 |
-
pats = re.findall(regex, key)
|
17 |
-
for pat in pats:
|
18 |
-
key = key.replace(pat, "_".join(pat.split(".")))
|
19 |
-
return key
|
20 |
-
|
21 |
-
|
22 |
-
# Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
|
23 |
-
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
24 |
-
# convert pytorch tensor to numpy
|
25 |
-
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
26 |
-
|
27 |
-
random_flax_state_dict = flatten_dict(flax_model.params)
|
28 |
-
flax_state_dict = {}
|
29 |
-
|
30 |
-
remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
|
31 |
-
flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
|
32 |
-
)
|
33 |
-
add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
|
34 |
-
flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
|
35 |
-
)
|
36 |
-
|
37 |
-
# Need to change some parameters name to match Flax names so that we don't have to fork any layer
|
38 |
-
for pt_key, pt_tensor in pt_state_dict.items():
|
39 |
-
pt_tuple_key = tuple(pt_key.split("."))
|
40 |
-
|
41 |
-
has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
|
42 |
-
require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict
|
43 |
-
|
44 |
-
if remove_base_model_prefix and has_base_model_prefix:
|
45 |
-
pt_tuple_key = pt_tuple_key[1:]
|
46 |
-
elif add_base_model_prefix and require_base_model_prefix:
|
47 |
-
pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
|
48 |
-
|
49 |
-
# Correctly rename weight parameters
|
50 |
-
if (
|
51 |
-
"norm" in pt_key
|
52 |
-
and (pt_tuple_key[-1] == "bias")
|
53 |
-
and (pt_tuple_key[:-1] + ("bias",) in random_flax_state_dict)
|
54 |
-
):
|
55 |
-
pt_tensor = pt_tensor[None, None, None, :]
|
56 |
-
elif (
|
57 |
-
"norm" in pt_key
|
58 |
-
and (pt_tuple_key[-1] == "bias")
|
59 |
-
and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
|
60 |
-
):
|
61 |
-
pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
|
62 |
-
pt_tensor = pt_tensor[None, None, None, :]
|
63 |
-
elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
|
64 |
-
pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
|
65 |
-
pt_tensor = pt_tensor[None, None, None, :]
|
66 |
-
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
|
67 |
-
pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
|
68 |
-
elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
|
69 |
-
# conv layer
|
70 |
-
pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
71 |
-
pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
|
72 |
-
elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
|
73 |
-
# linear layer
|
74 |
-
pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
75 |
-
pt_tensor = pt_tensor.T
|
76 |
-
elif pt_tuple_key[-1] == "gamma":
|
77 |
-
pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
|
78 |
-
elif pt_tuple_key[-1] == "beta":
|
79 |
-
pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
|
80 |
-
|
81 |
-
if pt_tuple_key in random_flax_state_dict:
|
82 |
-
if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
|
83 |
-
raise ValueError(
|
84 |
-
f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
|
85 |
-
f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
|
86 |
-
)
|
87 |
-
|
88 |
-
# also add unexpected weight so that warning is thrown
|
89 |
-
flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)
|
90 |
-
|
91 |
-
return unflatten_dict(flax_state_dict)
|
92 |
-
|
93 |
-
|
94 |
-
def convert_model(config_path, pt_state_dict_path, save_path):
|
95 |
-
config = VQGANConfig.from_pretrained(config_path)
|
96 |
-
model = VQModel(config)
|
97 |
-
|
98 |
-
state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
|
99 |
-
keys = list(state_dict.keys())
|
100 |
-
for key in keys:
|
101 |
-
if key.startswith("loss"):
|
102 |
-
state_dict.pop(key)
|
103 |
-
continue
|
104 |
-
renamed_key = rename_key(key)
|
105 |
-
state_dict[renamed_key] = state_dict.pop(key)
|
106 |
-
|
107 |
-
state = convert_pytorch_state_dict_to_flax(state_dict, model)
|
108 |
-
model.params = unflatten_dict(state)
|
109 |
-
model.save_pretrained(save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dalle_mini/vqgan_jax/modeling_flax_vqgan.py
DELETED
@@ -1,609 +0,0 @@
|
|
1 |
-
# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
|
2 |
-
|
3 |
-
from functools import partial
|
4 |
-
from typing import Tuple
|
5 |
-
import math
|
6 |
-
|
7 |
-
import jax
|
8 |
-
import jax.numpy as jnp
|
9 |
-
import numpy as np
|
10 |
-
import flax.linen as nn
|
11 |
-
from flax.core.frozen_dict import FrozenDict
|
12 |
-
|
13 |
-
from transformers.modeling_flax_utils import FlaxPreTrainedModel
|
14 |
-
|
15 |
-
from .configuration_vqgan import VQGANConfig
|
16 |
-
|
17 |
-
|
18 |
-
class Upsample(nn.Module):
|
19 |
-
in_channels: int
|
20 |
-
with_conv: bool
|
21 |
-
dtype: jnp.dtype = jnp.float32
|
22 |
-
|
23 |
-
def setup(self):
|
24 |
-
if self.with_conv:
|
25 |
-
self.conv = nn.Conv(
|
26 |
-
self.in_channels,
|
27 |
-
kernel_size=(3, 3),
|
28 |
-
strides=(1, 1),
|
29 |
-
padding=((1, 1), (1, 1)),
|
30 |
-
dtype=self.dtype,
|
31 |
-
)
|
32 |
-
|
33 |
-
def __call__(self, hidden_states):
|
34 |
-
batch, height, width, channels = hidden_states.shape
|
35 |
-
hidden_states = jax.image.resize(
|
36 |
-
hidden_states,
|
37 |
-
shape=(batch, height * 2, width * 2, channels),
|
38 |
-
method="nearest",
|
39 |
-
)
|
40 |
-
if self.with_conv:
|
41 |
-
hidden_states = self.conv(hidden_states)
|
42 |
-
return hidden_states
|
43 |
-
|
44 |
-
|
45 |
-
class Downsample(nn.Module):
|
46 |
-
in_channels: int
|
47 |
-
with_conv: bool
|
48 |
-
dtype: jnp.dtype = jnp.float32
|
49 |
-
|
50 |
-
def setup(self):
|
51 |
-
if self.with_conv:
|
52 |
-
self.conv = nn.Conv(
|
53 |
-
self.in_channels,
|
54 |
-
kernel_size=(3, 3),
|
55 |
-
strides=(2, 2),
|
56 |
-
padding="VALID",
|
57 |
-
dtype=self.dtype,
|
58 |
-
)
|
59 |
-
|
60 |
-
def __call__(self, hidden_states):
|
61 |
-
if self.with_conv:
|
62 |
-
pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
|
63 |
-
hidden_states = jnp.pad(hidden_states, pad_width=pad)
|
64 |
-
hidden_states = self.conv(hidden_states)
|
65 |
-
else:
|
66 |
-
hidden_states = nn.avg_pool(hidden_states, window_shape=(2, 2), strides=(2, 2), padding="VALID")
|
67 |
-
return hidden_states
|
68 |
-
|
69 |
-
|
70 |
-
class ResnetBlock(nn.Module):
|
71 |
-
in_channels: int
|
72 |
-
out_channels: int = None
|
73 |
-
use_conv_shortcut: bool = False
|
74 |
-
temb_channels: int = 512
|
75 |
-
dropout_prob: float = 0.0
|
76 |
-
dtype: jnp.dtype = jnp.float32
|
77 |
-
|
78 |
-
def setup(self):
|
79 |
-
self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
|
80 |
-
|
81 |
-
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
82 |
-
self.conv1 = nn.Conv(
|
83 |
-
self.out_channels_,
|
84 |
-
kernel_size=(3, 3),
|
85 |
-
strides=(1, 1),
|
86 |
-
padding=((1, 1), (1, 1)),
|
87 |
-
dtype=self.dtype,
|
88 |
-
)
|
89 |
-
|
90 |
-
if self.temb_channels:
|
91 |
-
self.temb_proj = nn.Dense(self.out_channels_, dtype=self.dtype)
|
92 |
-
|
93 |
-
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
94 |
-
self.dropout = nn.Dropout(self.dropout_prob)
|
95 |
-
self.conv2 = nn.Conv(
|
96 |
-
self.out_channels_,
|
97 |
-
kernel_size=(3, 3),
|
98 |
-
strides=(1, 1),
|
99 |
-
padding=((1, 1), (1, 1)),
|
100 |
-
dtype=self.dtype,
|
101 |
-
)
|
102 |
-
|
103 |
-
if self.in_channels != self.out_channels_:
|
104 |
-
if self.use_conv_shortcut:
|
105 |
-
self.conv_shortcut = nn.Conv(
|
106 |
-
self.out_channels_,
|
107 |
-
kernel_size=(3, 3),
|
108 |
-
strides=(1, 1),
|
109 |
-
padding=((1, 1), (1, 1)),
|
110 |
-
dtype=self.dtype,
|
111 |
-
)
|
112 |
-
else:
|
113 |
-
self.nin_shortcut = nn.Conv(
|
114 |
-
self.out_channels_,
|
115 |
-
kernel_size=(1, 1),
|
116 |
-
strides=(1, 1),
|
117 |
-
padding="VALID",
|
118 |
-
dtype=self.dtype,
|
119 |
-
)
|
120 |
-
|
121 |
-
def __call__(self, hidden_states, temb=None, deterministic: bool = True):
|
122 |
-
residual = hidden_states
|
123 |
-
hidden_states = self.norm1(hidden_states)
|
124 |
-
hidden_states = nn.swish(hidden_states)
|
125 |
-
hidden_states = self.conv1(hidden_states)
|
126 |
-
|
127 |
-
if temb is not None:
|
128 |
-
hidden_states = hidden_states + self.temb_proj(nn.swish(temb))[:, :, None, None] # TODO: check shapes
|
129 |
-
|
130 |
-
hidden_states = self.norm2(hidden_states)
|
131 |
-
hidden_states = nn.swish(hidden_states)
|
132 |
-
hidden_states = self.dropout(hidden_states, deterministic)
|
133 |
-
hidden_states = self.conv2(hidden_states)
|
134 |
-
|
135 |
-
if self.in_channels != self.out_channels_:
|
136 |
-
if self.use_conv_shortcut:
|
137 |
-
residual = self.conv_shortcut(residual)
|
138 |
-
else:
|
139 |
-
residual = self.nin_shortcut(residual)
|
140 |
-
|
141 |
-
return hidden_states + residual
|
142 |
-
|
143 |
-
|
144 |
-
class AttnBlock(nn.Module):
|
145 |
-
in_channels: int
|
146 |
-
dtype: jnp.dtype = jnp.float32
|
147 |
-
|
148 |
-
def setup(self):
|
149 |
-
conv = partial(
|
150 |
-
nn.Conv, self.in_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype
|
151 |
-
)
|
152 |
-
|
153 |
-
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
154 |
-
self.q, self.k, self.v = conv(), conv(), conv()
|
155 |
-
self.proj_out = conv()
|
156 |
-
|
157 |
-
def __call__(self, hidden_states):
|
158 |
-
residual = hidden_states
|
159 |
-
hidden_states = self.norm(hidden_states)
|
160 |
-
|
161 |
-
query = self.q(hidden_states)
|
162 |
-
key = self.k(hidden_states)
|
163 |
-
value = self.v(hidden_states)
|
164 |
-
|
165 |
-
# compute attentions
|
166 |
-
batch, height, width, channels = query.shape
|
167 |
-
query = query.reshape((batch, height * width, channels))
|
168 |
-
key = key.reshape((batch, height * width, channels))
|
169 |
-
attn_weights = jnp.einsum("...qc,...kc->...qk", query, key)
|
170 |
-
attn_weights = attn_weights * (int(channels) ** -0.5)
|
171 |
-
attn_weights = nn.softmax(attn_weights, axis=2)
|
172 |
-
|
173 |
-
## attend to values
|
174 |
-
value = value.reshape((batch, height * width, channels))
|
175 |
-
hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
|
176 |
-
hidden_states = hidden_states.reshape((batch, height, width, channels))
|
177 |
-
|
178 |
-
hidden_states = self.proj_out(hidden_states)
|
179 |
-
hidden_states = hidden_states + residual
|
180 |
-
return hidden_states
|
181 |
-
|
182 |
-
|
183 |
-
class UpsamplingBlock(nn.Module):
|
184 |
-
config: VQGANConfig
|
185 |
-
curr_res: int
|
186 |
-
block_idx: int
|
187 |
-
dtype: jnp.dtype = jnp.float32
|
188 |
-
|
189 |
-
def setup(self):
|
190 |
-
if self.block_idx == self.config.num_resolutions - 1:
|
191 |
-
block_in = self.config.ch * self.config.ch_mult[-1]
|
192 |
-
else:
|
193 |
-
block_in = self.config.ch * self.config.ch_mult[self.block_idx + 1]
|
194 |
-
|
195 |
-
block_out = self.config.ch * self.config.ch_mult[self.block_idx]
|
196 |
-
self.temb_ch = 0
|
197 |
-
|
198 |
-
res_blocks = []
|
199 |
-
attn_blocks = []
|
200 |
-
for _ in range(self.config.num_res_blocks + 1):
|
201 |
-
res_blocks.append(
|
202 |
-
ResnetBlock(
|
203 |
-
block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
|
204 |
-
)
|
205 |
-
)
|
206 |
-
block_in = block_out
|
207 |
-
if self.curr_res in self.config.attn_resolutions:
|
208 |
-
attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
|
209 |
-
|
210 |
-
self.block = res_blocks
|
211 |
-
self.attn = attn_blocks
|
212 |
-
|
213 |
-
self.upsample = None
|
214 |
-
if self.block_idx != 0:
|
215 |
-
self.upsample = Upsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
|
216 |
-
|
217 |
-
def __call__(self, hidden_states, temb=None, deterministic: bool = True):
|
218 |
-
for res_block in self.block:
|
219 |
-
hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
|
220 |
-
for attn_block in self.attn:
|
221 |
-
hidden_states = attn_block(hidden_states)
|
222 |
-
|
223 |
-
if self.upsample is not None:
|
224 |
-
hidden_states = self.upsample(hidden_states)
|
225 |
-
|
226 |
-
return hidden_states
|
227 |
-
|
228 |
-
|
229 |
-
class DownsamplingBlock(nn.Module):
|
230 |
-
config: VQGANConfig
|
231 |
-
curr_res: int
|
232 |
-
block_idx: int
|
233 |
-
dtype: jnp.dtype = jnp.float32
|
234 |
-
|
235 |
-
def setup(self):
|
236 |
-
in_ch_mult = (1,) + tuple(self.config.ch_mult)
|
237 |
-
block_in = self.config.ch * in_ch_mult[self.block_idx]
|
238 |
-
block_out = self.config.ch * self.config.ch_mult[self.block_idx]
|
239 |
-
self.temb_ch = 0
|
240 |
-
|
241 |
-
res_blocks = []
|
242 |
-
attn_blocks = []
|
243 |
-
for _ in range(self.config.num_res_blocks):
|
244 |
-
res_blocks.append(
|
245 |
-
ResnetBlock(
|
246 |
-
block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
|
247 |
-
)
|
248 |
-
)
|
249 |
-
block_in = block_out
|
250 |
-
if self.curr_res in self.config.attn_resolutions:
|
251 |
-
attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
|
252 |
-
|
253 |
-
self.block = res_blocks
|
254 |
-
self.attn = attn_blocks
|
255 |
-
|
256 |
-
self.downsample = None
|
257 |
-
if self.block_idx != self.config.num_resolutions - 1:
|
258 |
-
self.downsample = Downsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
|
259 |
-
|
260 |
-
def __call__(self, hidden_states, temb=None, deterministic: bool = True):
|
261 |
-
for res_block in self.block:
|
262 |
-
hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
|
263 |
-
for attn_block in self.attn:
|
264 |
-
hidden_states = attn_block(hidden_states)
|
265 |
-
|
266 |
-
if self.downsample is not None:
|
267 |
-
hidden_states = self.downsample(hidden_states)
|
268 |
-
|
269 |
-
return hidden_states
|
270 |
-
|
271 |
-
|
272 |
-
class MidBlock(nn.Module):
|
273 |
-
in_channels: int
|
274 |
-
temb_channels: int
|
275 |
-
dropout: float
|
276 |
-
dtype: jnp.dtype = jnp.float32
|
277 |
-
|
278 |
-
def setup(self):
|
279 |
-
self.block_1 = ResnetBlock(
|
280 |
-
self.in_channels,
|
281 |
-
self.in_channels,
|
282 |
-
temb_channels=self.temb_channels,
|
283 |
-
dropout_prob=self.dropout,
|
284 |
-
dtype=self.dtype,
|
285 |
-
)
|
286 |
-
self.attn_1 = AttnBlock(self.in_channels, dtype=self.dtype)
|
287 |
-
self.block_2 = ResnetBlock(
|
288 |
-
self.in_channels,
|
289 |
-
self.in_channels,
|
290 |
-
temb_channels=self.temb_channels,
|
291 |
-
dropout_prob=self.dropout,
|
292 |
-
dtype=self.dtype,
|
293 |
-
)
|
294 |
-
|
295 |
-
def __call__(self, hidden_states, temb=None, deterministic: bool = True):
|
296 |
-
hidden_states = self.block_1(hidden_states, temb, deterministic=deterministic)
|
297 |
-
hidden_states = self.attn_1(hidden_states)
|
298 |
-
hidden_states = self.block_2(hidden_states, temb, deterministic=deterministic)
|
299 |
-
return hidden_states
|
300 |
-
|
301 |
-
|
302 |
-
class Encoder(nn.Module):
|
303 |
-
config: VQGANConfig
|
304 |
-
dtype: jnp.dtype = jnp.float32
|
305 |
-
|
306 |
-
def setup(self):
|
307 |
-
self.temb_ch = 0
|
308 |
-
|
309 |
-
# downsampling
|
310 |
-
self.conv_in = nn.Conv(
|
311 |
-
self.config.ch,
|
312 |
-
kernel_size=(3, 3),
|
313 |
-
strides=(1, 1),
|
314 |
-
padding=((1, 1), (1, 1)),
|
315 |
-
dtype=self.dtype,
|
316 |
-
)
|
317 |
-
|
318 |
-
curr_res = self.config.resolution
|
319 |
-
downsample_blocks = []
|
320 |
-
for i_level in range(self.config.num_resolutions):
|
321 |
-
downsample_blocks.append(DownsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
|
322 |
-
|
323 |
-
if i_level != self.config.num_resolutions - 1:
|
324 |
-
curr_res = curr_res // 2
|
325 |
-
self.down = downsample_blocks
|
326 |
-
|
327 |
-
# middle
|
328 |
-
mid_channels = self.config.ch * self.config.ch_mult[-1]
|
329 |
-
self.mid = MidBlock(mid_channels, self.temb_ch, self.config.dropout, dtype=self.dtype)
|
330 |
-
|
331 |
-
# end
|
332 |
-
self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
333 |
-
self.conv_out = nn.Conv(
|
334 |
-
2 * self.config.z_channels if self.config.double_z else self.config.z_channels,
|
335 |
-
kernel_size=(3, 3),
|
336 |
-
strides=(1, 1),
|
337 |
-
padding=((1, 1), (1, 1)),
|
338 |
-
dtype=self.dtype,
|
339 |
-
)
|
340 |
-
|
341 |
-
def __call__(self, pixel_values, deterministic: bool = True):
|
342 |
-
# timestep embedding
|
343 |
-
temb = None
|
344 |
-
|
345 |
-
# downsampling
|
346 |
-
hidden_states = self.conv_in(pixel_values)
|
347 |
-
for block in self.down:
|
348 |
-
hidden_states = block(hidden_states, temb, deterministic=deterministic)
|
349 |
-
|
350 |
-
# middle
|
351 |
-
hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
|
352 |
-
|
353 |
-
# end
|
354 |
-
hidden_states = self.norm_out(hidden_states)
|
355 |
-
hidden_states = nn.swish(hidden_states)
|
356 |
-
hidden_states = self.conv_out(hidden_states)
|
357 |
-
|
358 |
-
return hidden_states
|
359 |
-
|
360 |
-
|
361 |
-
class Decoder(nn.Module):
|
362 |
-
config: VQGANConfig
|
363 |
-
dtype: jnp.dtype = jnp.float32
|
364 |
-
|
365 |
-
def setup(self):
|
366 |
-
self.temb_ch = 0
|
367 |
-
|
368 |
-
# compute in_ch_mult, block_in and curr_res at lowest res
|
369 |
-
block_in = self.config.ch * self.config.ch_mult[self.config.num_resolutions - 1]
|
370 |
-
curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
|
371 |
-
self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
|
372 |
-
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
373 |
-
|
374 |
-
# z to block_in
|
375 |
-
self.conv_in = nn.Conv(
|
376 |
-
block_in,
|
377 |
-
kernel_size=(3, 3),
|
378 |
-
strides=(1, 1),
|
379 |
-
padding=((1, 1), (1, 1)),
|
380 |
-
dtype=self.dtype,
|
381 |
-
)
|
382 |
-
|
383 |
-
# middle
|
384 |
-
self.mid = MidBlock(block_in, self.temb_ch, self.config.dropout, dtype=self.dtype)
|
385 |
-
|
386 |
-
# upsampling
|
387 |
-
upsample_blocks = []
|
388 |
-
for i_level in reversed(range(self.config.num_resolutions)):
|
389 |
-
upsample_blocks.append(UpsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
|
390 |
-
if i_level != 0:
|
391 |
-
curr_res = curr_res * 2
|
392 |
-
self.up = list(reversed(upsample_blocks)) # reverse to get consistent order
|
393 |
-
|
394 |
-
# end
|
395 |
-
self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
396 |
-
self.conv_out = nn.Conv(
|
397 |
-
self.config.out_ch,
|
398 |
-
kernel_size=(3, 3),
|
399 |
-
strides=(1, 1),
|
400 |
-
padding=((1, 1), (1, 1)),
|
401 |
-
dtype=self.dtype,
|
402 |
-
)
|
403 |
-
|
404 |
-
def __call__(self, hidden_states, deterministic: bool = True):
|
405 |
-
# timestep embedding
|
406 |
-
temb = None
|
407 |
-
|
408 |
-
# z to block_in
|
409 |
-
hidden_states = self.conv_in(hidden_states)
|
410 |
-
|
411 |
-
# middle
|
412 |
-
hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
|
413 |
-
|
414 |
-
# upsampling
|
415 |
-
for block in reversed(self.up):
|
416 |
-
hidden_states = block(hidden_states, temb, deterministic=deterministic)
|
417 |
-
|
418 |
-
# end
|
419 |
-
if self.config.give_pre_end:
|
420 |
-
return hidden_states
|
421 |
-
|
422 |
-
hidden_states = self.norm_out(hidden_states)
|
423 |
-
hidden_states = nn.swish(hidden_states)
|
424 |
-
hidden_states = self.conv_out(hidden_states)
|
425 |
-
|
426 |
-
return hidden_states
|
427 |
-
|
428 |
-
|
429 |
-
class VectorQuantizer(nn.Module):
|
430 |
-
"""
|
431 |
-
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
|
432 |
-
____________________________________________
|
433 |
-
Discretization bottleneck part of the VQ-VAE.
|
434 |
-
Inputs:
|
435 |
-
- n_e : number of embeddings
|
436 |
-
- e_dim : dimension of embedding
|
437 |
-
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
438 |
-
_____________________________________________
|
439 |
-
"""
|
440 |
-
|
441 |
-
config: VQGANConfig
|
442 |
-
dtype: jnp.dtype = jnp.float32
|
443 |
-
|
444 |
-
def setup(self):
|
445 |
-
self.embedding = nn.Embed(self.config.n_embed, self.config.embed_dim, dtype=self.dtype) # TODO: init
|
446 |
-
|
447 |
-
def __call__(self, hidden_states):
|
448 |
-
"""
|
449 |
-
Inputs the output of the encoder network z and maps it to a discrete
|
450 |
-
one-hot vector that is the index of the closest embedding vector e_j
|
451 |
-
z (continuous) -> z_q (discrete)
|
452 |
-
z.shape = (batch, channel, height, width)
|
453 |
-
quantization pipeline:
|
454 |
-
1. get encoder input (B,C,H,W)
|
455 |
-
2. flatten input to (B*H*W,C)
|
456 |
-
"""
|
457 |
-
# flatten
|
458 |
-
hidden_states_flattended = hidden_states.reshape((-1, self.config.embed_dim))
|
459 |
-
|
460 |
-
# dummy op to init the weights, so we can access them below
|
461 |
-
self.embedding(jnp.ones((1, 1), dtype="i4"))
|
462 |
-
|
463 |
-
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
464 |
-
emb_weights = self.variables["params"]["embedding"]["embedding"]
|
465 |
-
distance = (
|
466 |
-
jnp.sum(hidden_states_flattended ** 2, axis=1, keepdims=True)
|
467 |
-
+ jnp.sum(emb_weights ** 2, axis=1)
|
468 |
-
- 2 * jnp.dot(hidden_states_flattended, emb_weights.T)
|
469 |
-
)
|
470 |
-
|
471 |
-
# get quantized latent vectors
|
472 |
-
min_encoding_indices = jnp.argmin(distance, axis=1)
|
473 |
-
z_q = self.embedding(min_encoding_indices).reshape(hidden_states.shape)
|
474 |
-
|
475 |
-
# reshape to (batch, num_tokens)
|
476 |
-
min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)
|
477 |
-
|
478 |
-
# compute the codebook_loss (q_loss) outside the model
|
479 |
-
# here we return the embeddings and indices
|
480 |
-
return z_q, min_encoding_indices
|
481 |
-
|
482 |
-
def get_codebook_entry(self, indices, shape=None):
|
483 |
-
# indices are expected to be of shape (batch, num_tokens)
|
484 |
-
# get quantized latent vectors
|
485 |
-
batch, num_tokens = indices.shape
|
486 |
-
z_q = self.embedding(indices)
|
487 |
-
z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1)
|
488 |
-
return z_q
|
489 |
-
|
490 |
-
|
491 |
-
class VQModule(nn.Module):
|
492 |
-
config: VQGANConfig
|
493 |
-
dtype: jnp.dtype = jnp.float32
|
494 |
-
|
495 |
-
def setup(self):
|
496 |
-
self.encoder = Encoder(self.config, dtype=self.dtype)
|
497 |
-
self.decoder = Decoder(self.config, dtype=self.dtype)
|
498 |
-
self.quantize = VectorQuantizer(self.config, dtype=self.dtype)
|
499 |
-
self.quant_conv = nn.Conv(
|
500 |
-
self.config.embed_dim,
|
501 |
-
kernel_size=(1, 1),
|
502 |
-
strides=(1, 1),
|
503 |
-
padding="VALID",
|
504 |
-
dtype=self.dtype,
|
505 |
-
)
|
506 |
-
self.post_quant_conv = nn.Conv(
|
507 |
-
self.config.z_channels,
|
508 |
-
kernel_size=(1, 1),
|
509 |
-
strides=(1, 1),
|
510 |
-
padding="VALID",
|
511 |
-
dtype=self.dtype,
|
512 |
-
)
|
513 |
-
|
514 |
-
def encode(self, pixel_values, deterministic: bool = True):
|
515 |
-
hidden_states = self.encoder(pixel_values, deterministic=deterministic)
|
516 |
-
hidden_states = self.quant_conv(hidden_states)
|
517 |
-
quant_states, indices = self.quantize(hidden_states)
|
518 |
-
return quant_states, indices
|
519 |
-
|
520 |
-
def decode(self, hidden_states, deterministic: bool = True):
|
521 |
-
hidden_states = self.post_quant_conv(hidden_states)
|
522 |
-
hidden_states = self.decoder(hidden_states, deterministic=deterministic)
|
523 |
-
return hidden_states
|
524 |
-
|
525 |
-
def decode_code(self, code_b):
|
526 |
-
hidden_states = self.quantize.get_codebook_entry(code_b)
|
527 |
-
hidden_states = self.decode(hidden_states)
|
528 |
-
return hidden_states
|
529 |
-
|
530 |
-
def __call__(self, pixel_values, deterministic: bool = True):
|
531 |
-
quant_states, indices = self.encode(pixel_values, deterministic)
|
532 |
-
hidden_states = self.decode(quant_states, deterministic)
|
533 |
-
return hidden_states, indices
|
534 |
-
|
535 |
-
|
536 |
-
class VQGANPreTrainedModel(FlaxPreTrainedModel):
|
537 |
-
"""
|
538 |
-
An abstract class to handle weights initialization and a simple interface
|
539 |
-
for downloading and loading pretrained models.
|
540 |
-
"""
|
541 |
-
|
542 |
-
config_class = VQGANConfig
|
543 |
-
base_model_prefix = "model"
|
544 |
-
module_class: nn.Module = None
|
545 |
-
|
546 |
-
def __init__(
|
547 |
-
self,
|
548 |
-
config: VQGANConfig,
|
549 |
-
input_shape: Tuple = (1, 256, 256, 3),
|
550 |
-
seed: int = 0,
|
551 |
-
dtype: jnp.dtype = jnp.float32,
|
552 |
-
**kwargs,
|
553 |
-
):
|
554 |
-
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
555 |
-
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
556 |
-
|
557 |
-
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
558 |
-
# init input tensors
|
559 |
-
pixel_values = jnp.zeros(input_shape, dtype=jnp.float32)
|
560 |
-
params_rng, dropout_rng = jax.random.split(rng)
|
561 |
-
rngs = {"params": params_rng, "dropout": dropout_rng}
|
562 |
-
|
563 |
-
return self.module.init(rngs, pixel_values)["params"]
|
564 |
-
|
565 |
-
def encode(self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
|
566 |
-
# Handle any PRNG if needed
|
567 |
-
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
|
568 |
-
|
569 |
-
return self.module.apply(
|
570 |
-
{"params": params or self.params}, jnp.array(pixel_values), not train, rngs=rngs, method=self.module.encode
|
571 |
-
)
|
572 |
-
|
573 |
-
def decode(self, hidden_states, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
|
574 |
-
# Handle any PRNG if needed
|
575 |
-
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
|
576 |
-
|
577 |
-
return self.module.apply(
|
578 |
-
{"params": params or self.params},
|
579 |
-
jnp.array(hidden_states),
|
580 |
-
not train,
|
581 |
-
rngs=rngs,
|
582 |
-
method=self.module.decode,
|
583 |
-
)
|
584 |
-
|
585 |
-
def decode_code(self, indices, params: dict = None):
|
586 |
-
return self.module.apply(
|
587 |
-
{"params": params or self.params}, jnp.array(indices, dtype="i4"), method=self.module.decode_code
|
588 |
-
)
|
589 |
-
|
590 |
-
def __call__(
|
591 |
-
self,
|
592 |
-
pixel_values,
|
593 |
-
params: dict = None,
|
594 |
-
dropout_rng: jax.random.PRNGKey = None,
|
595 |
-
train: bool = False,
|
596 |
-
):
|
597 |
-
# Handle any PRNG if needed
|
598 |
-
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
|
599 |
-
|
600 |
-
return self.module.apply(
|
601 |
-
{"params": params or self.params},
|
602 |
-
jnp.array(pixel_values),
|
603 |
-
not train,
|
604 |
-
rngs=rngs,
|
605 |
-
)
|
606 |
-
|
607 |
-
|
608 |
-
class VQModel(VQGANPreTrainedModel):
|
609 |
-
module_class = VQModule
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/CC12M_downloader.py
DELETED
@@ -1,91 +0,0 @@
|
|
1 |
-
# Luke Melas-Kyriazi's code. (https://twitter.com/lukemelas)
|
2 |
-
|
3 |
-
#%%
|
4 |
-
import sys
|
5 |
-
import os
|
6 |
-
from datetime import datetime
|
7 |
-
import pandas as pd
|
8 |
-
import contexttimer
|
9 |
-
from urllib.request import urlopen
|
10 |
-
import requests
|
11 |
-
from PIL import Image
|
12 |
-
import torch
|
13 |
-
from torchvision.transforms import functional as TF
|
14 |
-
from multiprocessing import Pool
|
15 |
-
from tqdm import tqdm
|
16 |
-
import logging
|
17 |
-
|
18 |
-
# Setup
|
19 |
-
logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)
|
20 |
-
requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)
|
21 |
-
|
22 |
-
|
23 |
-
# # For downloading SVG images (I can't get this to work)
|
24 |
-
# from io import BytesIO
|
25 |
-
# import cairosvg
|
26 |
-
|
27 |
-
#%%
|
28 |
-
# Load data
|
29 |
-
print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
|
30 |
-
with contexttimer.Timer(prefix="Loading from tsv"):
|
31 |
-
df = pd.read_csv('./cc12m.tsv', delimiter='\t', header=None)
|
32 |
-
|
33 |
-
url_to_idx_map = {url: index for index, url, caption in df.itertuples()}
|
34 |
-
print(f'Loaded {len(url_to_idx_map)} urls')
|
35 |
-
|
36 |
-
#%%
|
37 |
-
df.head()
|
38 |
-
|
39 |
-
#%%
|
40 |
-
|
41 |
-
# Note: it seems that there are no SVG images
|
42 |
-
df.sample(10000)[1].str.contains('.svg').sum()
|
43 |
-
|
44 |
-
#%%
|
45 |
-
# Resize function
|
46 |
-
def resize(img):
|
47 |
-
max_size_of_short_side = 512
|
48 |
-
if min(img.size) > max_size_of_short_side:
|
49 |
-
img = TF.resize(img, size=max_size_of_short_side, interpolation=Image.LANCZOS)
|
50 |
-
return img
|
51 |
-
|
52 |
-
base_dir = os.path.join(os.getcwd(), 'images')
|
53 |
-
|
54 |
-
def process(item):
|
55 |
-
url, image_id = item
|
56 |
-
try:
|
57 |
-
base_url = os.path.basename(url) # extract base url
|
58 |
-
stem, ext = os.path.splitext(base_url) # split into stem and extension
|
59 |
-
filename = f'{image_id:08d}---{stem}.jpg' # create filename
|
60 |
-
filepath = os.path.join(base_dir, filename) # concat to get filepath
|
61 |
-
if not os.path.isfile(filepath):
|
62 |
-
# if filepath.endswith('.svg'):
|
63 |
-
# raise NotImplementedError()
|
64 |
-
# image_bytes = BytesIO() # create a bytestream
|
65 |
-
# cairosvg.svg2png(url=url, write_to=image_bytes) # convert svg into image
|
66 |
-
# else:
|
67 |
-
req = requests.get(url, stream=True, timeout=1, verify=False).raw
|
68 |
-
image = Image.open(req).convert('RGB')
|
69 |
-
if min(image.size) > 512:
|
70 |
-
image = TF.resize(image, size=512, interpolation=Image.LANCZOS)
|
71 |
-
# image = resize(image) # resize PIL image
|
72 |
-
image.save(filepath) # save PIL image
|
73 |
-
except Exception as e:
|
74 |
-
logging.info(" ".join(repr(e).splitlines()))
|
75 |
-
logging.error(url)
|
76 |
-
|
77 |
-
#%%
|
78 |
-
#for i, item in enumerate(tqdm(url_to_idx_map.items(), total=len(url_to_idx_map))):
|
79 |
-
# process(item)
|
80 |
-
# if i > 100:
|
81 |
-
# break
|
82 |
-
|
83 |
-
# Use multiprocessing for speed
|
84 |
-
list_of_items = list(url_to_idx_map.items())
|
85 |
-
print(len(list_of_items))
|
86 |
-
list_of_items = list_of_items[10_000_000:]
|
87 |
-
print(len(list_of_items))
|
88 |
-
with Pool(128) as p:
|
89 |
-
r = list(tqdm(p.imap(process, list_of_items), total=len(list_of_items)))
|
90 |
-
print('DONE')
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/CC3M_downloader.py
DELETED
@@ -1,62 +0,0 @@
|
|
1 |
-
'''
|
2 |
-
This script was adapted from Luke Melas-Kyriazi's code. (https://twitter.com/lukemelas)
|
3 |
-
Few changes were made for the particular dataset. You're required to have the `.tsv` file downloaded in your directory.
|
4 |
-
Find them here- [https://github.com/google-research-datasets/conceptual-captions]
|
5 |
-
'''
|
6 |
-
|
7 |
-
import sys
|
8 |
-
import os
|
9 |
-
from datetime import datetime
|
10 |
-
import pandas as pd
|
11 |
-
import contexttimer
|
12 |
-
from urllib.request import urlopen
|
13 |
-
import requests
|
14 |
-
from PIL import Image
|
15 |
-
import torch
|
16 |
-
from torchvision.transforms import functional as TF
|
17 |
-
from multiprocessing import Pool
|
18 |
-
from tqdm import tqdm
|
19 |
-
import logging
|
20 |
-
import sys
|
21 |
-
|
22 |
-
# Setup
|
23 |
-
logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)
|
24 |
-
requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)
|
25 |
-
|
26 |
-
if len(sys.argv) != 3:
|
27 |
-
print("Provide .tsv file name & output directory. e.g. python downloader.py Train-GCC-training.tsv training")
|
28 |
-
exit(1)
|
29 |
-
|
30 |
-
# Load data
|
31 |
-
print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
|
32 |
-
with contexttimer.Timer(prefix="Loading from tsv"):
|
33 |
-
df = pd.read_csv(sys.argv[1], delimiter='\t', header=None)
|
34 |
-
|
35 |
-
url_to_idx_map = {url: index for index, caption, url in df.itertuples()}
|
36 |
-
print(f'Loaded {len(url_to_idx_map)} urls')
|
37 |
-
|
38 |
-
base_dir = os.path.join(os.getcwd(), sys.argv[2])
|
39 |
-
|
40 |
-
def process(item):
|
41 |
-
url, image_id = item
|
42 |
-
try:
|
43 |
-
base_url = os.path.basename(url) # extract base url
|
44 |
-
stem, ext = os.path.splitext(base_url) # split into stem and extension
|
45 |
-
filename = f'{image_id:08d}---{stem}.jpg' # create filename
|
46 |
-
filepath = os.path.join(base_dir, filename) # concat to get filepath
|
47 |
-
if not os.path.isfile(filepath):
|
48 |
-
req = requests.get(url, stream=True, timeout=1, verify=False).raw
|
49 |
-
image = Image.open(req).convert('RGB')
|
50 |
-
if min(image.size) > 512:
|
51 |
-
image = TF.resize(image, size=512, interpolation=Image.LANCZOS)
|
52 |
-
image.save(filepath) # save PIL image
|
53 |
-
except Exception as e:
|
54 |
-
logging.info(" ".join(repr(e).splitlines()))
|
55 |
-
logging.error(url)
|
56 |
-
|
57 |
-
list_of_items = list(url_to_idx_map.items())
|
58 |
-
print(len(list_of_items))
|
59 |
-
|
60 |
-
with Pool(128) as p:
|
61 |
-
r = list(tqdm(p.imap(process, list_of_items), total=len(list_of_items)))
|
62 |
-
print('DONE')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo/CustomBARTv4b_model-generate.ipynb
DELETED
@@ -1,566 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"nbformat": 4,
|
3 |
-
"nbformat_minor": 0,
|
4 |
-
"metadata": {
|
5 |
-
"colab": {
|
6 |
-
"name": "CustomBARTv4b-model-generate.ipynb",
|
7 |
-
"provenance": [],
|
8 |
-
"collapsed_sections": [],
|
9 |
-
"machine_shape": "hm"
|
10 |
-
},
|
11 |
-
"kernelspec": {
|
12 |
-
"name": "python3",
|
13 |
-
"display_name": "Python 3"
|
14 |
-
},
|
15 |
-
"language_info": {
|
16 |
-
"name": "python"
|
17 |
-
},
|
18 |
-
"accelerator": "TPU"
|
19 |
-
},
|
20 |
-
"cells": [
|
21 |
-
{
|
22 |
-
"cell_type": "markdown",
|
23 |
-
"metadata": {
|
24 |
-
"id": "ewer-Q-0w2xA"
|
25 |
-
},
|
26 |
-
"source": [
|
27 |
-
"# Installation"
|
28 |
-
]
|
29 |
-
},
|
30 |
-
{
|
31 |
-
"cell_type": "code",
|
32 |
-
"metadata": {
|
33 |
-
"colab": {
|
34 |
-
"base_uri": "https://localhost:8080/"
|
35 |
-
},
|
36 |
-
"id": "NpsF9ipLLl2s",
|
37 |
-
"outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32"
|
38 |
-
},
|
39 |
-
"source": [
|
40 |
-
"!pip install git+https://github.com/huggingface/transformers/\n",
|
41 |
-
"!pip install git+https://github.com/google/flax"
|
42 |
-
],
|
43 |
-
"execution_count": 1,
|
44 |
-
"outputs": [
|
45 |
-
{
|
46 |
-
"output_type": "stream",
|
47 |
-
"text": [
|
48 |
-
"Collecting git+https://github.com/huggingface/transformers/\n",
|
49 |
-
" Cloning https://github.com/huggingface/transformers/ to /tmp/pip-req-build-oxejx1op\n",
|
50 |
-
" Running command git clone -q https://github.com/huggingface/transformers/ /tmp/pip-req-build-oxejx1op\n",
|
51 |
-
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
|
52 |
-
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
|
53 |
-
" Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n",
|
54 |
-
"Requirement already satisfied (use --upgrade to upgrade): transformers==4.9.0.dev0 from git+https://github.com/huggingface/transformers/ in /usr/local/lib/python3.7/dist-packages\n",
|
55 |
-
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (1.19.5)\n",
|
56 |
-
"Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (20.9)\n",
|
57 |
-
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (5.4.1)\n",
|
58 |
-
"Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (0.0.45)\n",
|
59 |
-
"Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (4.6.0)\n",
|
60 |
-
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (4.41.1)\n",
|
61 |
-
"Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (3.0.12)\n",
|
62 |
-
"Requirement already satisfied: huggingface-hub==0.0.12 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (0.0.12)\n",
|
63 |
-
"Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (0.10.3)\n",
|
64 |
-
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (2019.12.20)\n",
|
65 |
-
"Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (2.23.0)\n",
|
66 |
-
"Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->transformers==4.9.0.dev0) (2.4.7)\n",
|
67 |
-
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.9.0.dev0) (1.15.0)\n",
|
68 |
-
"Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.9.0.dev0) (1.0.1)\n",
|
69 |
-
"Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.9.0.dev0) (7.1.2)\n",
|
70 |
-
"Requirement already satisfied: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->transformers==4.9.0.dev0) (3.7.4.3)\n",
|
71 |
-
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->transformers==4.9.0.dev0) (3.4.1)\n",
|
72 |
-
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (2021.5.30)\n",
|
73 |
-
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (3.0.4)\n",
|
74 |
-
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (1.24.3)\n",
|
75 |
-
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (2.10)\n",
|
76 |
-
"Building wheels for collected packages: transformers\n",
|
77 |
-
" Building wheel for transformers (PEP 517) ... \u001b[?25l\u001b[?25hdone\n",
|
78 |
-
" Created wheel for transformers: filename=transformers-4.9.0.dev0-cp37-none-any.whl size=2582229 sha256=249c593273ccca3027c6427d2c6fd749a89f21d722d628d97eb438a2cf3185a8\n",
|
79 |
-
" Stored in directory: /tmp/pip-ephem-wheel-cache-l2rqt1b7/wheels/61/69/33/974fccec4d0ab5feee9fe83bd93e680d269a805be9ede5ec60\n",
|
80 |
-
"Successfully built transformers\n",
|
81 |
-
"Collecting git+https://github.com/google/flax\n",
|
82 |
-
" Cloning https://github.com/google/flax to /tmp/pip-req-build-rt9g1_wx\n",
|
83 |
-
" Running command git clone -q https://github.com/google/flax /tmp/pip-req-build-rt9g1_wx\n",
|
84 |
-
"Requirement already satisfied (use --upgrade to upgrade): flax==0.3.4 from git+https://github.com/google/flax in /usr/local/lib/python3.7/dist-packages\n",
|
85 |
-
"Requirement already satisfied: numpy>=1.12 in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (1.19.5)\n",
|
86 |
-
"Requirement already satisfied: jax>=0.2.13 in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (0.2.13)\n",
|
87 |
-
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (3.2.2)\n",
|
88 |
-
"Requirement already satisfied: msgpack in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (1.0.2)\n",
|
89 |
-
"Requirement already satisfied: optax in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (0.0.9)\n",
|
90 |
-
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->flax==0.3.4) (3.3.0)\n",
|
91 |
-
"Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->flax==0.3.4) (0.12.0)\n",
|
92 |
-
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (2.8.1)\n",
|
93 |
-
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (0.10.0)\n",
|
94 |
-
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (2.4.7)\n",
|
95 |
-
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (1.3.1)\n",
|
96 |
-
"Requirement already satisfied: chex>=0.0.4 in /usr/local/lib/python3.7/dist-packages (from optax->flax==0.3.4) (0.0.8)\n",
|
97 |
-
"Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax->flax==0.3.4) (0.1.66+cuda110)\n",
|
98 |
-
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax>=0.2.13->flax==0.3.4) (1.15.0)\n",
|
99 |
-
"Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax==0.3.4) (0.1.6)\n",
|
100 |
-
"Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax==0.3.4) (0.11.1)\n",
|
101 |
-
"Requirement already satisfied: flatbuffers in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax->flax==0.3.4) (1.12)\n",
|
102 |
-
"Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax->flax==0.3.4) (1.4.1)\n",
|
103 |
-
"Building wheels for collected packages: flax\n",
|
104 |
-
" Building wheel for flax (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
|
105 |
-
" Created wheel for flax: filename=flax-0.3.4-cp37-none-any.whl size=184692 sha256=503b27995f372afe33631e71572d5edc1fffd4d2e0a4cd206d291ad6b0e4c299\n",
|
106 |
-
" Stored in directory: /tmp/pip-ephem-wheel-cache-g1pzxnv6/wheels/3d/26/f4/0ea6051d7352289d9e4f8178348452b35a9a97bde6035405a5\n",
|
107 |
-
"Successfully built flax\n"
|
108 |
-
],
|
109 |
-
"name": "stdout"
|
110 |
-
}
|
111 |
-
]
|
112 |
-
},
|
113 |
-
{
|
114 |
-
"cell_type": "code",
|
115 |
-
"metadata": {
|
116 |
-
"id": "M1wVkrpjU6zO"
|
117 |
-
},
|
118 |
-
"source": [
|
119 |
-
"%load_ext autoreload\n",
|
120 |
-
"%autoreload 2"
|
121 |
-
],
|
122 |
-
"execution_count": 2,
|
123 |
-
"outputs": []
|
124 |
-
},
|
125 |
-
{
|
126 |
-
"cell_type": "markdown",
|
127 |
-
"metadata": {
|
128 |
-
"id": "t47CH1H_IOT8"
|
129 |
-
},
|
130 |
-
"source": [
|
131 |
-
"# Custom BART Model"
|
132 |
-
]
|
133 |
-
},
|
134 |
-
{
|
135 |
-
"cell_type": "code",
|
136 |
-
"metadata": {
|
137 |
-
"id": "9jQnM6S2vCpn"
|
138 |
-
},
|
139 |
-
"source": [
|
140 |
-
"# TODO: set those args in a config file\n",
|
141 |
-
"OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
|
142 |
-
"OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
|
143 |
-
"BOS_TOKEN_ID = 16384\n",
|
144 |
-
"BASE_MODEL = 'facebook/bart-large'"
|
145 |
-
],
|
146 |
-
"execution_count": 3,
|
147 |
-
"outputs": []
|
148 |
-
},
|
149 |
-
{
|
150 |
-
"cell_type": "code",
|
151 |
-
"metadata": {
|
152 |
-
"id": "_eEaJVxAKpV5"
|
153 |
-
},
|
154 |
-
"source": [
|
155 |
-
"import jax\n",
|
156 |
-
"import flax.linen as nn\n",
|
157 |
-
"\n",
|
158 |
-
"from transformers.models.bart.modeling_flax_bart import *\n",
|
159 |
-
"from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
|
160 |
-
"\n",
|
161 |
-
"class CustomFlaxBartModule(FlaxBartModule):\n",
|
162 |
-
" def setup(self):\n",
|
163 |
-
" # we keep shared to easily load pre-trained weights\n",
|
164 |
-
" self.shared = nn.Embed(\n",
|
165 |
-
" self.config.vocab_size,\n",
|
166 |
-
" self.config.d_model,\n",
|
167 |
-
" embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
|
168 |
-
" dtype=self.dtype,\n",
|
169 |
-
" )\n",
|
170 |
-
" # a separate embedding is used for the decoder\n",
|
171 |
-
" self.decoder_embed = nn.Embed(\n",
|
172 |
-
" OUTPUT_VOCAB_SIZE,\n",
|
173 |
-
" self.config.d_model,\n",
|
174 |
-
" embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
|
175 |
-
" dtype=self.dtype,\n",
|
176 |
-
" )\n",
|
177 |
-
" self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
|
178 |
-
"\n",
|
179 |
-
" # the decoder has a different config\n",
|
180 |
-
" decoder_config = BartConfig(self.config.to_dict())\n",
|
181 |
-
" decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
|
182 |
-
" decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
|
183 |
-
" self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
|
184 |
-
"\n",
|
185 |
-
"class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
|
186 |
-
" def setup(self):\n",
|
187 |
-
" self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
|
188 |
-
" self.lm_head = nn.Dense(\n",
|
189 |
-
" OUTPUT_VOCAB_SIZE,\n",
|
190 |
-
" use_bias=False,\n",
|
191 |
-
" dtype=self.dtype,\n",
|
192 |
-
" kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
|
193 |
-
" )\n",
|
194 |
-
" self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
|
195 |
-
"\n",
|
196 |
-
"class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
|
197 |
-
" module_class = CustomFlaxBartForConditionalGenerationModule"
|
198 |
-
],
|
199 |
-
"execution_count": 4,
|
200 |
-
"outputs": []
|
201 |
-
},
|
202 |
-
{
|
203 |
-
"cell_type": "code",
|
204 |
-
"metadata": {
|
205 |
-
"id": "S7CP9Td9m2ge",
|
206 |
-
"colab": {
|
207 |
-
"base_uri": "https://localhost:8080/"
|
208 |
-
},
|
209 |
-
"outputId": "5638ef68-9c40-46f7-90ba-a4d05b61360d"
|
210 |
-
},
|
211 |
-
"source": [
|
212 |
-
"# load pre-trained model for encoder weights\n",
|
213 |
-
"base_model = FlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)"
|
214 |
-
],
|
215 |
-
"execution_count": 5,
|
216 |
-
"outputs": [
|
217 |
-
{
|
218 |
-
"output_type": "stream",
|
219 |
-
"text": [
|
220 |
-
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
|
221 |
-
],
|
222 |
-
"name": "stderr"
|
223 |
-
}
|
224 |
-
]
|
225 |
-
},
|
226 |
-
{
|
227 |
-
"cell_type": "code",
|
228 |
-
"metadata": {
|
229 |
-
"id": "6lmynR-poceH"
|
230 |
-
},
|
231 |
-
"source": [
|
232 |
-
"# set up our new model config\n",
|
233 |
-
"config = BartConfig.from_pretrained(BASE_MODEL)\n",
|
234 |
-
"config.tie_word_embeddings = False\n",
|
235 |
-
"config.decoder_start_token_id = BOS_TOKEN_ID\n",
|
236 |
-
"config.bos_token_id = BOS_TOKEN_ID # should not be used\n",
|
237 |
-
"config.pos_token_id = BOS_TOKEN_ID # should not be used\n",
|
238 |
-
"#config.eos_token_id = None # prevents generation from stopping until we reach max_length"
|
239 |
-
],
|
240 |
-
"execution_count": 6,
|
241 |
-
"outputs": []
|
242 |
-
},
|
243 |
-
{
|
244 |
-
"cell_type": "code",
|
245 |
-
"metadata": {
|
246 |
-
"id": "_6-XKK40oEfP"
|
247 |
-
},
|
248 |
-
"source": [
|
249 |
-
"# create our model and initialize it randomly\n",
|
250 |
-
"model = CustomFlaxBartForConditionalGeneration(config)"
|
251 |
-
],
|
252 |
-
"execution_count": 7,
|
253 |
-
"outputs": []
|
254 |
-
},
|
255 |
-
{
|
256 |
-
"cell_type": "code",
|
257 |
-
"metadata": {
|
258 |
-
"id": "-r_hZestr-NR"
|
259 |
-
},
|
260 |
-
"source": [
|
261 |
-
"# use pretrained weights\n",
|
262 |
-
"model.params['model']['encoder'] = base_model.params['model']['encoder']\n",
|
263 |
-
"model.params['model']['shared'] = base_model.params['model']['shared']"
|
264 |
-
],
|
265 |
-
"execution_count": 8,
|
266 |
-
"outputs": []
|
267 |
-
},
|
268 |
-
{
|
269 |
-
"cell_type": "code",
|
270 |
-
"metadata": {
|
271 |
-
"id": "5NEX8f62sVjx"
|
272 |
-
},
|
273 |
-
"source": [
|
274 |
-
"# no need for base_model anymore\n",
|
275 |
-
"del base_model"
|
276 |
-
],
|
277 |
-
"execution_count": 9,
|
278 |
-
"outputs": []
|
279 |
-
},
|
280 |
-
{
|
281 |
-
"cell_type": "code",
|
282 |
-
"metadata": {
|
283 |
-
"colab": {
|
284 |
-
"base_uri": "https://localhost:8080/"
|
285 |
-
},
|
286 |
-
"id": "Jz032w73nHEf",
|
287 |
-
"outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
|
288 |
-
},
|
289 |
-
"source": [
|
290 |
-
"# we verify that the shape has not been modified\n",
|
291 |
-
"model.params['final_logits_bias'].shape"
|
292 |
-
],
|
293 |
-
"execution_count": 10,
|
294 |
-
"outputs": [
|
295 |
-
{
|
296 |
-
"output_type": "execute_result",
|
297 |
-
"data": {
|
298 |
-
"text/plain": [
|
299 |
-
"(1, 16385)"
|
300 |
-
]
|
301 |
-
},
|
302 |
-
"metadata": {
|
303 |
-
"tags": []
|
304 |
-
},
|
305 |
-
"execution_count": 10
|
306 |
-
}
|
307 |
-
]
|
308 |
-
},
|
309 |
-
{
|
310 |
-
"cell_type": "markdown",
|
311 |
-
"metadata": {
|
312 |
-
"id": "zLl24Ez5t7x1"
|
313 |
-
},
|
314 |
-
"source": [
|
315 |
-
"## Inference"
|
316 |
-
]
|
317 |
-
},
|
318 |
-
{
|
319 |
-
"cell_type": "code",
|
320 |
-
"metadata": {
|
321 |
-
"id": "XLLA2NK3uDQr"
|
322 |
-
},
|
323 |
-
"source": [
|
324 |
-
"tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)"
|
325 |
-
],
|
326 |
-
"execution_count": 11,
|
327 |
-
"outputs": []
|
328 |
-
},
|
329 |
-
{
|
330 |
-
"cell_type": "code",
|
331 |
-
"metadata": {
|
332 |
-
"colab": {
|
333 |
-
"base_uri": "https://localhost:8080/"
|
334 |
-
},
|
335 |
-
"id": "Ntow53I_t81D",
|
336 |
-
"outputId": "59289cdd-1429-4720-cc87-88810c4b99ac"
|
337 |
-
},
|
338 |
-
"source": [
|
339 |
-
"text = \"My friends are cool but they eat too many carbs.\"\n",
|
340 |
-
"inputs = tokenizer(text, max_length=1024, return_tensors='jax')\n",
|
341 |
-
"encoder_outputs = model.encode(**inputs)"
|
342 |
-
],
|
343 |
-
"execution_count": 12,
|
344 |
-
"outputs": [
|
345 |
-
{
|
346 |
-
"output_type": "stream",
|
347 |
-
"text": [
|
348 |
-
"Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n"
|
349 |
-
],
|
350 |
-
"name": "stderr"
|
351 |
-
}
|
352 |
-
]
|
353 |
-
},
|
354 |
-
{
|
355 |
-
"cell_type": "code",
|
356 |
-
"metadata": {
|
357 |
-
"colab": {
|
358 |
-
"base_uri": "https://localhost:8080/"
|
359 |
-
},
|
360 |
-
"id": "vcRNJnJ_uJOJ",
|
361 |
-
"outputId": "025afd54-7908-4a9c-fb59-e40bd3458711"
|
362 |
-
},
|
363 |
-
"source": [
|
364 |
-
"decoder_start_token_id = model.config.decoder_start_token_id\n",
|
365 |
-
"decoder_start_token_id"
|
366 |
-
],
|
367 |
-
"execution_count": 13,
|
368 |
-
"outputs": [
|
369 |
-
{
|
370 |
-
"output_type": "execute_result",
|
371 |
-
"data": {
|
372 |
-
"text/plain": [
|
373 |
-
"16384"
|
374 |
-
]
|
375 |
-
},
|
376 |
-
"metadata": {
|
377 |
-
"tags": []
|
378 |
-
},
|
379 |
-
"execution_count": 13
|
380 |
-
}
|
381 |
-
]
|
382 |
-
},
|
383 |
-
{
|
384 |
-
"cell_type": "code",
|
385 |
-
"metadata": {
|
386 |
-
"id": "6QWmEwL_uMld"
|
387 |
-
},
|
388 |
-
"source": [
|
389 |
-
"decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype=\"i4\") * decoder_start_token_id\n",
|
390 |
-
"outputs = model.decode(decoder_input_ids, encoder_outputs)"
|
391 |
-
],
|
392 |
-
"execution_count": 14,
|
393 |
-
"outputs": []
|
394 |
-
},
|
395 |
-
{
|
396 |
-
"cell_type": "code",
|
397 |
-
"metadata": {
|
398 |
-
"colab": {
|
399 |
-
"base_uri": "https://localhost:8080/"
|
400 |
-
},
|
401 |
-
"id": "c_ys3yWBothF",
|
402 |
-
"outputId": "40d4d584-e0a8-44cb-bbea-0ffa38d50a53"
|
403 |
-
},
|
404 |
-
"source": [
|
405 |
-
"outputs"
|
406 |
-
],
|
407 |
-
"execution_count": 15,
|
408 |
-
"outputs": [
|
409 |
-
{
|
410 |
-
"output_type": "execute_result",
|
411 |
-
"data": {
|
412 |
-
"text/plain": [
|
413 |
-
"FlaxCausalLMOutputWithCrossAttentions([('logits',\n",
|
414 |
-
" DeviceArray([[[ 0.5263986 , -2.0947676 , -0.18830685, ..., 0.7599884 ,\n",
|
415 |
-
" 0.6746795 , -1.0411576 ]]], dtype=float32))])"
|
416 |
-
]
|
417 |
-
},
|
418 |
-
"metadata": {
|
419 |
-
"tags": []
|
420 |
-
},
|
421 |
-
"execution_count": 15
|
422 |
-
}
|
423 |
-
]
|
424 |
-
},
|
425 |
-
{
|
426 |
-
"cell_type": "code",
|
427 |
-
"metadata": {
|
428 |
-
"colab": {
|
429 |
-
"base_uri": "https://localhost:8080/"
|
430 |
-
},
|
431 |
-
"id": "O6s0wtB_uTC_",
|
432 |
-
"outputId": "bc0e9e80-e346-4e99-d28e-3f658eda1f66"
|
433 |
-
},
|
434 |
-
"source": [
|
435 |
-
"outputs.logits.shape"
|
436 |
-
],
|
437 |
-
"execution_count": 16,
|
438 |
-
"outputs": [
|
439 |
-
{
|
440 |
-
"output_type": "execute_result",
|
441 |
-
"data": {
|
442 |
-
"text/plain": [
|
443 |
-
"(1, 1, 16385)"
|
444 |
-
]
|
445 |
-
},
|
446 |
-
"metadata": {
|
447 |
-
"tags": []
|
448 |
-
},
|
449 |
-
"execution_count": 16
|
450 |
-
}
|
451 |
-
]
|
452 |
-
},
|
453 |
-
{
|
454 |
-
"cell_type": "code",
|
455 |
-
"metadata": {
|
456 |
-
"colab": {
|
457 |
-
"base_uri": "https://localhost:8080/"
|
458 |
-
},
|
459 |
-
"id": "ELzemGP3uBzy",
|
460 |
-
"outputId": "dc12f98a-1ccf-450d-ba2a-9c29d7d14885"
|
461 |
-
},
|
462 |
-
"source": [
|
463 |
-
"outputs.logits.argmax(axis=-1)"
|
464 |
-
],
|
465 |
-
"execution_count": 17,
|
466 |
-
"outputs": [
|
467 |
-
{
|
468 |
-
"output_type": "execute_result",
|
469 |
-
"data": {
|
470 |
-
"text/plain": [
|
471 |
-
"DeviceArray([[12459]], dtype=int32)"
|
472 |
-
]
|
473 |
-
},
|
474 |
-
"metadata": {
|
475 |
-
"tags": []
|
476 |
-
},
|
477 |
-
"execution_count": 17
|
478 |
-
}
|
479 |
-
]
|
480 |
-
},
|
481 |
-
{
|
482 |
-
"cell_type": "code",
|
483 |
-
"metadata": {
|
484 |
-
"colab": {
|
485 |
-
"base_uri": "https://localhost:8080/"
|
486 |
-
},
|
487 |
-
"id": "fQjikkGEunpx",
|
488 |
-
"outputId": "3dba0209-ad4e-4069-be38-6c599c677ef1"
|
489 |
-
},
|
490 |
-
"source": [
|
491 |
-
"model.config.bos_token_id, model.config.eos_token_id, model.config.pad_token_id"
|
492 |
-
],
|
493 |
-
"execution_count": 18,
|
494 |
-
"outputs": [
|
495 |
-
{
|
496 |
-
"output_type": "execute_result",
|
497 |
-
"data": {
|
498 |
-
"text/plain": [
|
499 |
-
"(16384, 2, 1)"
|
500 |
-
]
|
501 |
-
},
|
502 |
-
"metadata": {
|
503 |
-
"tags": []
|
504 |
-
},
|
505 |
-
"execution_count": 18
|
506 |
-
}
|
507 |
-
]
|
508 |
-
},
|
509 |
-
{
|
510 |
-
"cell_type": "code",
|
511 |
-
"metadata": {
|
512 |
-
"id": "P32mJJSbrU1F"
|
513 |
-
},
|
514 |
-
"source": [
|
515 |
-
"input_ids_test = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')"
|
516 |
-
],
|
517 |
-
"execution_count": 19,
|
518 |
-
"outputs": []
|
519 |
-
},
|
520 |
-
{
|
521 |
-
"cell_type": "code",
|
522 |
-
"metadata": {
|
523 |
-
"id": "C7cHbIHruELT"
|
524 |
-
},
|
525 |
-
"source": [
|
526 |
-
"greedy_output = model.generate(input_ids_test, max_length=50)"
|
527 |
-
],
|
528 |
-
"execution_count": 20,
|
529 |
-
"outputs": []
|
530 |
-
},
|
531 |
-
{
|
532 |
-
"cell_type": "code",
|
533 |
-
"metadata": {
|
534 |
-
"colab": {
|
535 |
-
"base_uri": "https://localhost:8080/"
|
536 |
-
},
|
537 |
-
"id": "jYugh9cOuwc9",
|
538 |
-
"outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
|
539 |
-
},
|
540 |
-
"source": [
|
541 |
-
"greedy_output[0]"
|
542 |
-
],
|
543 |
-
"execution_count": 21,
|
544 |
-
"outputs": [
|
545 |
-
{
|
546 |
-
"output_type": "execute_result",
|
547 |
-
"data": {
|
548 |
-
"text/plain": [
|
549 |
-
"DeviceArray([[16384, 0, 3570, 13405, 10186, 2392, 16362, 1869,\n",
|
550 |
-
" 15772, 13546, 15772, 13546, 9348, 14791, 15772, 15772,\n",
|
551 |
-
" 15772, 11272, 15772, 13546, 15772, 15772, 13546, 15772,\n",
|
552 |
-
" 13546, 15772, 6642, 15772, 10776, 6431, 15772, 14567,\n",
|
553 |
-
" 13406, 15772, 14567, 6235, 15772, 4909, 16160, 568,\n",
|
554 |
-
" 4664, 6650, 8952, 9089, 15772, 5952, 7375, 10843,\n",
|
555 |
-
" 8952, 2]], dtype=int32)"
|
556 |
-
]
|
557 |
-
},
|
558 |
-
"metadata": {
|
559 |
-
"tags": []
|
560 |
-
},
|
561 |
-
"execution_count": 21
|
562 |
-
}
|
563 |
-
]
|
564 |
-
}
|
565 |
-
]
|
566 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo/demo_notebook.ipynb
DELETED
@@ -1,583 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "markdown",
|
5 |
-
"metadata": {
|
6 |
-
"id": "ewer-Q-0w2xA"
|
7 |
-
},
|
8 |
-
"source": [
|
9 |
-
"# Installation"
|
10 |
-
]
|
11 |
-
},
|
12 |
-
{
|
13 |
-
"cell_type": "code",
|
14 |
-
"execution_count": null,
|
15 |
-
"metadata": {
|
16 |
-
"colab": {
|
17 |
-
"base_uri": "https://localhost:8080/"
|
18 |
-
},
|
19 |
-
"id": "NpsF9ipLLl2s",
|
20 |
-
"outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32"
|
21 |
-
},
|
22 |
-
"outputs": [],
|
23 |
-
"source": [
|
24 |
-
"#!pip install git+https://github.com/huggingface/transformers/\n",
|
25 |
-
"#!pip install git+https://github.com/google/flax"
|
26 |
-
]
|
27 |
-
},
|
28 |
-
{
|
29 |
-
"cell_type": "code",
|
30 |
-
"execution_count": 1,
|
31 |
-
"metadata": {
|
32 |
-
"id": "M1wVkrpjU6zO"
|
33 |
-
},
|
34 |
-
"outputs": [],
|
35 |
-
"source": [
|
36 |
-
"%load_ext autoreload\n",
|
37 |
-
"%autoreload 2"
|
38 |
-
]
|
39 |
-
},
|
40 |
-
{
|
41 |
-
"cell_type": "code",
|
42 |
-
"execution_count": 2,
|
43 |
-
"metadata": {},
|
44 |
-
"outputs": [
|
45 |
-
{
|
46 |
-
"name": "stdout",
|
47 |
-
"output_type": "stream",
|
48 |
-
"text": [
|
49 |
-
"/home/tmabraham/vqgan-jax\n"
|
50 |
-
]
|
51 |
-
}
|
52 |
-
],
|
53 |
-
"source": [
|
54 |
-
"%cd ../../vqgan-jax"
|
55 |
-
]
|
56 |
-
},
|
57 |
-
{
|
58 |
-
"cell_type": "markdown",
|
59 |
-
"metadata": {
|
60 |
-
"id": "t47CH1H_IOT8"
|
61 |
-
},
|
62 |
-
"source": [
|
63 |
-
"# Custom BART Model"
|
64 |
-
]
|
65 |
-
},
|
66 |
-
{
|
67 |
-
"cell_type": "code",
|
68 |
-
"execution_count": 3,
|
69 |
-
"metadata": {
|
70 |
-
"id": "9jQnM6S2vCpn"
|
71 |
-
},
|
72 |
-
"outputs": [],
|
73 |
-
"source": [
|
74 |
-
"# TODO: set those args in a config file\n",
|
75 |
-
"OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
|
76 |
-
"OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
|
77 |
-
"BOS_TOKEN_ID = 16384\n",
|
78 |
-
"BASE_MODEL = 'facebook/bart-large'"
|
79 |
-
]
|
80 |
-
},
|
81 |
-
{
|
82 |
-
"cell_type": "code",
|
83 |
-
"execution_count": 4,
|
84 |
-
"metadata": {
|
85 |
-
"id": "_eEaJVxAKpV5"
|
86 |
-
},
|
87 |
-
"outputs": [],
|
88 |
-
"source": [
|
89 |
-
"import jax\n",
|
90 |
-
"import flax.linen as nn\n",
|
91 |
-
"\n",
|
92 |
-
"from transformers.models.bart.modeling_flax_bart import *\n",
|
93 |
-
"from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
|
94 |
-
"\n",
|
95 |
-
"class CustomFlaxBartModule(FlaxBartModule):\n",
|
96 |
-
" def setup(self):\n",
|
97 |
-
" # we keep shared to easily load pre-trained weights\n",
|
98 |
-
" self.shared = nn.Embed(\n",
|
99 |
-
" self.config.vocab_size,\n",
|
100 |
-
" self.config.d_model,\n",
|
101 |
-
" embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
|
102 |
-
" dtype=self.dtype,\n",
|
103 |
-
" )\n",
|
104 |
-
" # a separate embedding is used for the decoder\n",
|
105 |
-
" self.decoder_embed = nn.Embed(\n",
|
106 |
-
" OUTPUT_VOCAB_SIZE,\n",
|
107 |
-
" self.config.d_model,\n",
|
108 |
-
" embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
|
109 |
-
" dtype=self.dtype,\n",
|
110 |
-
" )\n",
|
111 |
-
" self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
|
112 |
-
"\n",
|
113 |
-
" # the decoder has a different config\n",
|
114 |
-
" decoder_config = BartConfig(self.config.to_dict())\n",
|
115 |
-
" decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
|
116 |
-
" decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
|
117 |
-
" self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
|
118 |
-
"\n",
|
119 |
-
"class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
|
120 |
-
" def setup(self):\n",
|
121 |
-
" self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
|
122 |
-
" self.lm_head = nn.Dense(\n",
|
123 |
-
" OUTPUT_VOCAB_SIZE,\n",
|
124 |
-
" use_bias=False,\n",
|
125 |
-
" dtype=self.dtype,\n",
|
126 |
-
" kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
|
127 |
-
" )\n",
|
128 |
-
" self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
|
129 |
-
"\n",
|
130 |
-
"class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
|
131 |
-
" module_class = CustomFlaxBartForConditionalGenerationModule"
|
132 |
-
]
|
133 |
-
},
|
134 |
-
{
|
135 |
-
"cell_type": "code",
|
136 |
-
"execution_count": 5,
|
137 |
-
"metadata": {
|
138 |
-
"scrolled": true
|
139 |
-
},
|
140 |
-
"outputs": [
|
141 |
-
{
|
142 |
-
"name": "stderr",
|
143 |
-
"output_type": "stream",
|
144 |
-
"text": [
|
145 |
-
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mtmabraham\u001b[0m (use `wandb login --relogin` to force relogin)\n"
|
146 |
-
]
|
147 |
-
},
|
148 |
-
{
|
149 |
-
"data": {
|
150 |
-
"text/html": [
|
151 |
-
"\n",
|
152 |
-
" Tracking run with wandb version 0.10.33<br/>\n",
|
153 |
-
" Syncing run <strong style=\"color:#cdcd00\">rare-night-7</strong> to <a href=\"https://wandb.ai\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
|
154 |
-
" Project page: <a href=\"https://wandb.ai/tmabraham/vqgan-jax\" target=\"_blank\">https://wandb.ai/tmabraham/vqgan-jax</a><br/>\n",
|
155 |
-
" Run page: <a href=\"https://wandb.ai/tmabraham/vqgan-jax/runs/qzxavce8\" target=\"_blank\">https://wandb.ai/tmabraham/vqgan-jax/runs/qzxavce8</a><br/>\n",
|
156 |
-
" Run data is saved locally in <code>/home/tmabraham/vqgan-jax/wandb/run-20210715_075019-qzxavce8</code><br/><br/>\n",
|
157 |
-
" "
|
158 |
-
],
|
159 |
-
"text/plain": [
|
160 |
-
"<IPython.core.display.HTML object>"
|
161 |
-
]
|
162 |
-
},
|
163 |
-
"metadata": {},
|
164 |
-
"output_type": "display_data"
|
165 |
-
},
|
166 |
-
{
|
167 |
-
"name": "stderr",
|
168 |
-
"output_type": "stream",
|
169 |
-
"text": [
|
170 |
-
"\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-1ef8yxby:latest, 1674.97MB. 2 files... Done. 0:0:0\n"
|
171 |
-
]
|
172 |
-
}
|
173 |
-
],
|
174 |
-
"source": [
|
175 |
-
"import wandb\n",
|
176 |
-
"run = wandb.init()\n",
|
177 |
-
"artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-1ef8yxby:latest', type='bart_model')\n",
|
178 |
-
"artifact_dir = artifact.download()"
|
179 |
-
]
|
180 |
-
},
|
181 |
-
{
|
182 |
-
"cell_type": "code",
|
183 |
-
"execution_count": 6,
|
184 |
-
"metadata": {
|
185 |
-
"id": "_6-XKK40oEfP",
|
186 |
-
"scrolled": true
|
187 |
-
},
|
188 |
-
"outputs": [
|
189 |
-
{
|
190 |
-
"name": "stderr",
|
191 |
-
"output_type": "stream",
|
192 |
-
"text": [
|
193 |
-
"/home/tmabraham/dalle-mini/src/transformers/src/transformers/models/bart/configuration_bart.py:180: UserWarning: Please make sure the config includes `forced_bos_token_id=16384` in future versions.The config can simply be saved and uploaded again to be fixed.\n",
|
194 |
-
" warnings.warn(\n",
|
195 |
-
"INFO:absl:Starting the local TPU driver.\n",
|
196 |
-
"INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
|
197 |
-
"INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: TPU Interpreter Host\n"
|
198 |
-
]
|
199 |
-
}
|
200 |
-
],
|
201 |
-
"source": [
|
202 |
-
"# create our model and initialize it randomly\n",
|
203 |
-
"model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)"
|
204 |
-
]
|
205 |
-
},
|
206 |
-
{
|
207 |
-
"cell_type": "code",
|
208 |
-
"execution_count": 7,
|
209 |
-
"metadata": {},
|
210 |
-
"outputs": [],
|
211 |
-
"source": [
|
212 |
-
"model.config.forced_bos_token_id = None"
|
213 |
-
]
|
214 |
-
},
|
215 |
-
{
|
216 |
-
"cell_type": "code",
|
217 |
-
"execution_count": 8,
|
218 |
-
"metadata": {
|
219 |
-
"colab": {
|
220 |
-
"base_uri": "https://localhost:8080/"
|
221 |
-
},
|
222 |
-
"id": "Jz032w73nHEf",
|
223 |
-
"outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
|
224 |
-
},
|
225 |
-
"outputs": [
|
226 |
-
{
|
227 |
-
"data": {
|
228 |
-
"text/plain": [
|
229 |
-
"(1, 16385)"
|
230 |
-
]
|
231 |
-
},
|
232 |
-
"execution_count": 8,
|
233 |
-
"metadata": {},
|
234 |
-
"output_type": "execute_result"
|
235 |
-
}
|
236 |
-
],
|
237 |
-
"source": [
|
238 |
-
"# we verify that the shape has not been modified\n",
|
239 |
-
"model.params['final_logits_bias'].shape"
|
240 |
-
]
|
241 |
-
},
|
242 |
-
{
|
243 |
-
"cell_type": "markdown",
|
244 |
-
"metadata": {
|
245 |
-
"id": "zLl24Ez5t7x1"
|
246 |
-
},
|
247 |
-
"source": [
|
248 |
-
"## Inference"
|
249 |
-
]
|
250 |
-
},
|
251 |
-
{
|
252 |
-
"cell_type": "code",
|
253 |
-
"execution_count": 9,
|
254 |
-
"metadata": {
|
255 |
-
"id": "XLLA2NK3uDQr"
|
256 |
-
},
|
257 |
-
"outputs": [],
|
258 |
-
"source": [
|
259 |
-
"tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)"
|
260 |
-
]
|
261 |
-
},
|
262 |
-
{
|
263 |
-
"cell_type": "code",
|
264 |
-
"execution_count": 10,
|
265 |
-
"metadata": {},
|
266 |
-
"outputs": [],
|
267 |
-
"source": [
|
268 |
-
"input_text = ['I enjoy walking with my cute dog']*8"
|
269 |
-
]
|
270 |
-
},
|
271 |
-
{
|
272 |
-
"cell_type": "code",
|
273 |
-
"execution_count": 11,
|
274 |
-
"metadata": {
|
275 |
-
"id": "P32mJJSbrU1F"
|
276 |
-
},
|
277 |
-
"outputs": [],
|
278 |
-
"source": [
|
279 |
-
"input_ids_test = tokenizer(input_text, return_tensors='jax')"
|
280 |
-
]
|
281 |
-
},
|
282 |
-
{
|
283 |
-
"cell_type": "code",
|
284 |
-
"execution_count": 12,
|
285 |
-
"metadata": {},
|
286 |
-
"outputs": [
|
287 |
-
{
|
288 |
-
"data": {
|
289 |
-
"text/plain": [
|
290 |
-
"{'input_ids': DeviceArray([[ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
|
291 |
-
" 2],\n",
|
292 |
-
" [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
|
293 |
-
" 2],\n",
|
294 |
-
" [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
|
295 |
-
" 2],\n",
|
296 |
-
" [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
|
297 |
-
" 2],\n",
|
298 |
-
" [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
|
299 |
-
" 2],\n",
|
300 |
-
" [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
|
301 |
-
" 2],\n",
|
302 |
-
" [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
|
303 |
-
" 2],\n",
|
304 |
-
" [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
|
305 |
-
" 2]], dtype=int32), 'attention_mask': DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
|
306 |
-
" [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
|
307 |
-
" [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
|
308 |
-
" [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
|
309 |
-
" [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
|
310 |
-
" [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
|
311 |
-
" [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
|
312 |
-
" [1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)}"
|
313 |
-
]
|
314 |
-
},
|
315 |
-
"execution_count": 12,
|
316 |
-
"metadata": {},
|
317 |
-
"output_type": "execute_result"
|
318 |
-
}
|
319 |
-
],
|
320 |
-
"source": [
|
321 |
-
"input_ids_test"
|
322 |
-
]
|
323 |
-
},
|
324 |
-
{
|
325 |
-
"cell_type": "code",
|
326 |
-
"execution_count": 13,
|
327 |
-
"metadata": {
|
328 |
-
"id": "C7cHbIHruELT"
|
329 |
-
},
|
330 |
-
"outputs": [],
|
331 |
-
"source": [
|
332 |
-
"greedy_output = model.generate(input_ids_test['input_ids'], max_length=257)"
|
333 |
-
]
|
334 |
-
},
|
335 |
-
{
|
336 |
-
"cell_type": "code",
|
337 |
-
"execution_count": 14,
|
338 |
-
"metadata": {},
|
339 |
-
"outputs": [
|
340 |
-
{
|
341 |
-
"data": {
|
342 |
-
"text/plain": [
|
343 |
-
"(8, 257)"
|
344 |
-
]
|
345 |
-
},
|
346 |
-
"execution_count": 14,
|
347 |
-
"metadata": {},
|
348 |
-
"output_type": "execute_result"
|
349 |
-
}
|
350 |
-
],
|
351 |
-
"source": [
|
352 |
-
"greedy_output[0].shape"
|
353 |
-
]
|
354 |
-
},
|
355 |
-
{
|
356 |
-
"cell_type": "code",
|
357 |
-
"execution_count": 15,
|
358 |
-
"metadata": {
|
359 |
-
"colab": {
|
360 |
-
"base_uri": "https://localhost:8080/"
|
361 |
-
},
|
362 |
-
"id": "jYugh9cOuwc9",
|
363 |
-
"outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
|
364 |
-
},
|
365 |
-
"outputs": [
|
366 |
-
{
|
367 |
-
"data": {
|
368 |
-
"text/plain": [
|
369 |
-
"DeviceArray([[16384, 10042, 10042, ..., 10042, 10042, 9570],\n",
|
370 |
-
" [16384, 10042, 10042, ..., 10042, 10042, 9570],\n",
|
371 |
-
" [16384, 10042, 10042, ..., 10042, 10042, 9570],\n",
|
372 |
-
" ...,\n",
|
373 |
-
" [16384, 10042, 10042, ..., 10042, 10042, 9570],\n",
|
374 |
-
" [16384, 10042, 10042, ..., 10042, 10042, 9570],\n",
|
375 |
-
" [16384, 10042, 10042, ..., 10042, 10042, 9570]], dtype=int32)"
|
376 |
-
]
|
377 |
-
},
|
378 |
-
"execution_count": 15,
|
379 |
-
"metadata": {},
|
380 |
-
"output_type": "execute_result"
|
381 |
-
}
|
382 |
-
],
|
383 |
-
"source": [
|
384 |
-
"greedy_output[0]"
|
385 |
-
]
|
386 |
-
},
|
387 |
-
{
|
388 |
-
"cell_type": "code",
|
389 |
-
"execution_count": 16,
|
390 |
-
"metadata": {},
|
391 |
-
"outputs": [
|
392 |
-
{
|
393 |
-
"data": {
|
394 |
-
"text/plain": [
|
395 |
-
"DeviceArray([16384, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
396 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
397 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
398 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
399 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
400 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
401 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
402 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
403 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
404 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
405 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
406 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
407 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
408 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
409 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
410 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
411 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
412 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
413 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
414 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
415 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
416 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
417 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
418 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
419 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
420 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
421 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
422 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
423 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
424 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
425 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
426 |
-
" 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
|
427 |
-
" 9570], dtype=int32)"
|
428 |
-
]
|
429 |
-
},
|
430 |
-
"execution_count": 16,
|
431 |
-
"metadata": {},
|
432 |
-
"output_type": "execute_result"
|
433 |
-
}
|
434 |
-
],
|
435 |
-
"source": [
|
436 |
-
"greedy_output[0][0]"
|
437 |
-
]
|
438 |
-
},
|
439 |
-
{
|
440 |
-
"cell_type": "markdown",
|
441 |
-
"metadata": {},
|
442 |
-
"source": [
|
443 |
-
"# VGAN Jax"
|
444 |
-
]
|
445 |
-
},
|
446 |
-
{
|
447 |
-
"cell_type": "code",
|
448 |
-
"execution_count": 17,
|
449 |
-
"metadata": {},
|
450 |
-
"outputs": [],
|
451 |
-
"source": [
|
452 |
-
"import io\n",
|
453 |
-
"\n",
|
454 |
-
"import requests\n",
|
455 |
-
"from PIL import Image\n",
|
456 |
-
"import numpy as np\n",
|
457 |
-
"\n",
|
458 |
-
"import torch\n",
|
459 |
-
"import torchvision.transforms as T\n",
|
460 |
-
"import torchvision.transforms.functional as TF\n",
|
461 |
-
"from torchvision.transforms import InterpolationMode"
|
462 |
-
]
|
463 |
-
},
|
464 |
-
{
|
465 |
-
"cell_type": "code",
|
466 |
-
"execution_count": 18,
|
467 |
-
"metadata": {},
|
468 |
-
"outputs": [],
|
469 |
-
"source": [
|
470 |
-
"from modeling_flax_vqgan import VQModel"
|
471 |
-
]
|
472 |
-
},
|
473 |
-
{
|
474 |
-
"cell_type": "code",
|
475 |
-
"execution_count": 19,
|
476 |
-
"metadata": {},
|
477 |
-
"outputs": [],
|
478 |
-
"source": [
|
479 |
-
"def custom_to_pil(x):\n",
|
480 |
-
" x = np.clip(x, 0., 1.)\n",
|
481 |
-
" x = (255*x).astype(np.uint8)\n",
|
482 |
-
" x = Image.fromarray(x)\n",
|
483 |
-
" if not x.mode == \"RGB\":\n",
|
484 |
-
" x = x.convert(\"RGB\")\n",
|
485 |
-
" return x"
|
486 |
-
]
|
487 |
-
},
|
488 |
-
{
|
489 |
-
"cell_type": "code",
|
490 |
-
"execution_count": 20,
|
491 |
-
"metadata": {
|
492 |
-
"colab": {
|
493 |
-
"base_uri": "https://localhost:8080/"
|
494 |
-
},
|
495 |
-
"id": "Jz032w73nHEf",
|
496 |
-
"outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49",
|
497 |
-
"scrolled": true
|
498 |
-
},
|
499 |
-
"outputs": [
|
500 |
-
{
|
501 |
-
"name": "stdout",
|
502 |
-
"output_type": "stream",
|
503 |
-
"text": [
|
504 |
-
"Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
|
505 |
-
]
|
506 |
-
}
|
507 |
-
],
|
508 |
-
"source": [
|
509 |
-
"model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
|
510 |
-
]
|
511 |
-
},
|
512 |
-
{
|
513 |
-
"cell_type": "code",
|
514 |
-
"execution_count": 21,
|
515 |
-
"metadata": {},
|
516 |
-
"outputs": [],
|
517 |
-
"source": [
|
518 |
-
"def get_images(indices, model):\n",
|
519 |
-
" indices = indices[:, 1:]\n",
|
520 |
-
" print(indices.shape)\n",
|
521 |
-
" img = model.decode_code(indices)\n",
|
522 |
-
" return img"
|
523 |
-
]
|
524 |
-
},
|
525 |
-
{
|
526 |
-
"cell_type": "code",
|
527 |
-
"execution_count": 22,
|
528 |
-
"metadata": {},
|
529 |
-
"outputs": [
|
530 |
-
{
|
531 |
-
"name": "stdout",
|
532 |
-
"output_type": "stream",
|
533 |
-
"text": [
|
534 |
-
"(1, 256)\n",
|
535 |
-
"Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
|
536 |
-
]
|
537 |
-
},
|
538 |
-
{
|
539 |
-
"data": {
|
540 |
-
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQAAAAEACAIAAADTED8xAAAtSElEQVR4nO1d6Y4jOXKOyFJ3zwzWXgP2An4Dv/8D+QF8ADvrxczsTFcp/CMzmWRcDDJPqfShu0pi8giScZNS4cevf/3tx28/wjcAJKABBoI7ASEMCEgAAAQAAEhwB0CYgfOD9A4BaHqxVHvhk2NkBZpfU1auvs5bRTqn7GcOEkMDwB1gALgDfcD9HT4Qvt6+f337AkBw/wAkGIaJiREABxhZHgGIgAAw52xBMc7l+GL/F9YjwkWyTmL6USpywRhf3wEGQIK3G+AXgBvcfxre7gQEAG9wJ4Bh4nhEoAEGhDsBIgwA91lSJx1Pmvjh/JOJIAhSoFQM6oRJSDmJOu5KoTG+fCQ7fhqwfUiFcmUS1NWwtghKxgN3YTvVo08uq5kqD2Jgmn8O8A7wFf/4+A6j34MfAIjwls2EAIaxyegOZQPT6DLh/DMrL5ZEEM/oeeEFSPoNAQJakbcUqq14w7Q2ACLQO9Ad7l/hdgP4ABgAhpn1px5nR2jxbUpalkelx4+i5oghLMDPBib90gzmkDXjqtjqp1qfUaJ6z6xnx2JKslXrLUhCwfTyhT6WVTBTMqppoMmrJwQaAEfjcAOCAd5mnwcBhjvcZ84GmmLfRL0VBBclQvcjL9bhro9SOaokzgXZb32vy3na6q5V67cS2UQbaeWqS9AaQEa2HzP2G3mYZneGAG7DXAeBAIjgPpoh4n4cCyqIFiEpplRqOHXuwakF6/T57q06VO2B6TUWgMXjCqZ/QbS1HGAW48k+I1Y34nJU1Y0fxIXgbKdPvYwnncqzWz/mdeBGBDCp/OTxJ7ee8faIez6KJe5ob9TW6Ot4fcgr94vsR63ExDWyNUqwGqtQ1ewdZqlnqYW/nWvXgpuSXNoOU6GSCIhgSLVuOCABEgHNrj5NgewyYen9V9mORJ31HOcgomz9OlKLOEo4SBIYTVRvWHUJZLml8GXnV3cNXdCyCSPy2FLhK8sbz9/SlBTKtTzcgAaEpPAJAHE+85p9MlWXBFwvm6DII9TYsRqQ+XY6TrQq5R0sJZtYy9dKs0PMQ7P+jFz/4qyfze0tuUipM5sIRKCx5hgH3wgIgBCG+7LphHP+NBs170oOzdldUtGxYT47ygrZ66rLq/YktYbqfDNxa4PLnR2Komm0oEnz1QhDt3tbiSvKPbUdHIU9FnORPJEUXtN00gUEcAe4IWJ2wwHnJWEuV04xez2KZrGUpS98ij5q5RXLC3ec745ROshYOVZfh+tjmKaGk5qfi9KdmjocP1MSl7EvzYoebvPYlHqcT3lhdoFSJ6pCV9X9RcCVXJ8WvR4cm9w5P6mKrZgkAtkV639+rQQsKQVv0RlE6nooRqTZIEyHAzcWaNBUCWA6GaiOW6r7uQ+cgumoHO/Dl7zXp+B+EPNYOy3VEWHKVd1Ia4Ols5jblKwJiRdOZw2Q4geFDCxvbwCkTgOLFchluLom468GM6pVsqLH9PQJUh07oXllIltFGlurrapuSc5SYeomsxBsFVcJt5belig53H8dc3cyXFOj7ZyME0ExITzFM9xmxI6khV/HSaTEO8M5SFi/7GONm6/Ip6CA24EtQex38UQdLpLGkd6s5Td3TypivtnoHayZzyhOZHU4KbrcUd/VXWQ5lhq5VPyaHfX4x07U8IaABtsCOOHQfpEk16llNkD4cgttaQF9j9R5ai1lN9dqTmUP1MRDa8PIU2UdrPyGZfjUHXLKzYGVKuR8FkVvUH+CAHAjItTC9CxOzyHdvw1NPNepYg6W0jVCrKKCX265NL5uskasrknVHfDXVlWa3e6WOZZv4/xClgKRXbXQSqlDXMJLtY9pUMttyE91xxW8ZX2g075K3worvy1ynk4oLEbJ7HJrcjcgP+WwrKJ0sZrsieqhOSDxU+2feRkJ0qKCqMNqWjSoU5Yro3erjldjoCQDy1uDNrN9/noAuGX0jBc8qYv7odwSx+IdAC9sYIxjuNi8fv5SZbHAxvuP4krWr8eI5D0sd77q/VdHD07ZaYXygW8cSWyNMZ1CF6SjLUnKTSUqN2SWrXW5+wqmYILB4ma1eIf5W1XTegowNlDVpY60lY8kGV2eyXqYwwY1KPLTKt11txYKAYZR4ZcRBvnNwriEDARirA1AhmCoxAQ7zLulcgg5XIQkB8eskj2ybkwisZS0AMQ7ROXVXHOQn8BBW3p88yQerpagLsiNR/HiLLDFlDxNNcbddU2bZIY1XDEmQ8MupVs85ImNGi9NNdNViNy1IlYplUgBMoiVPnZ6wyQu6iwlw50cMHDlrDk+EoiHpVZzdcRIyeaQc9lcxlY7UVVW0VsRABaBjQUeBoxNbtlbfeLOcgSnSkXvy82LmYl5WMNsljRBunhp41r9SDSxSBX+cCtFqxW5e3PAuNV13gpsXq1e3KjWh1n9V8nuSI8ujWUYR6nTsmb2nLpHlIh4zEFEljgumR2dByGXN95wP7RsKPl82aiw9OfjXVFkypK9SI27l0bVedLf1WhUHbhT/IgjRlrZudQyrM/qDh5gnVoijUpMVA2ZMuhTv6WHVi/yUbdesYYg8ymJn+yRWoKNmx4BCUvGnF454lY0qP2oe0La50gIWq6m28xkRVU+72yC6XPqkcs/xBU36k9mjAKQFs5ZJoKSC6IrmoHtnuTTuBDWBukjMNItK7HeqiWtcDSP2rlTWCEmsMIyA5K/UGVjmy0gl7RqbOMILpSX4aS48LbVHYggX7AgZzc5jrV2lhKFcspHhqlBbEtS4S6R8sjS7pY/q8LUaZbcOLBYuTqY1slUV94GTWth9rZSqKXmzN045sU63BoYxNdAah0ZATkajhGbGq735CNxsj8Km521wLKQ9eCznDWoLNdVaocYtALL3+ztbaQpW/Lq/lV5tZvEDUHiRaRya514YROC6x/spBo7xfuMD6qWs1hujQ+ht03KNNfxaquUfB9yAvlLTinN35mVavLKrYYw9ZL/kw0v6I68sBrLbiucZDTxP2pLS7Von4MsksEOmQ8rbRNNaPA32IZ2LDwgK/fCeWBuSaRBs93wLaAiAJkvNP4qwgyDHXlEJW0eCg8xFwkWylD54oUnxWIBwnyd+yDhAexixwIQrwsA9esZBAuvE2Z8nzpnYSaW/0C8wBYN8cJDoc0ClH+qoiIFbp/Lye/g9KRF6FYIpXtm0vqo+t7qzhnyhaeAqmSd2swBqSQqI2clN6jIk/kBsVk80hzyj84sR5Iqf4PB4mqKziHuhcfHxDL5H9ryPaJSL+t5VCepY8UA1ojM58kTSXr+nCyiXFz28OmF/UGzZz9+ef8BHq8bA5h1y3c1R6Xy2WImPIdM+4UHAO6e+yg6H9TSmZTlpwrTaiwP9ZnITOsr3flCLyzHmZeq+VDvqxGdZulnUt6CCipb5ISaf1/serDo2slcbWIIH96a0mQHgrNgDKaqcpO/vBigQ/dbFcuv2i1infP2Sppayv4B8/eK93sJ7CbcD+VEHuw0hQoPorWpXu4kcvK/BNMJQ+6WYgSwbrSetzPyOI6xe0Fa1Ru8BpasoiaxDyMGcv0Z+jZCuNkE2V2g+upYuR1jaRk/obIpU+x/MAiyuyKWDxagSq2i8d4RoHwxk+7RGOUxbALVeLJjAtmCTAoNY1+PvjRSR41lMJ2mh0XAyyCSM9SUbjtQvNgJOXuTLDKGZ9rn4ibNJG81t0xKiiJBcBCta4mZuqKyvHXok9DHPXnigOlrv39ZgXOwmlsLEMOGuIpUtMTBnSNg5BxghBNG+BWaOkT37eoRdg25WfTpEMN8JMshqVbwHwWnajmDzn5Wt3pDF2t3TahbAOmT+P5PB5Lek1a5W6jKnpyRVy6rdcfE0Vgd67SNY5i6UPfUHd15ZJkOOdqOKrzDQBBnN8MCSJ8EaoO1kELFL3voXpCm7KncHgfVqUhlr8a+rFqrXmxbCb93FD9BkB4e2vLW5NxXgdjv7YHQfhVibbU0MGWN/JBUGyHCTFLfWZobS96ITMXZbCYYB8GfXv420LCuzUjxmoyKa7Gv1+oJgDNyNWTzUY3vGsmJ1GSaWO57Jdff59iu2X8SP3NKWm0K6we1QqMuuAOq7oJTcg4sd6NqAfQVXnkwNHtvfXmuuO/nq+FQP/m29znyKk3WKNKagOBEtcOqxUGjptEkf06yGvJqe0ORvXadyxeycg4gQxtZAbXXEWhaJ9JB64qbtFORaCuq5ROXQbmabWRrxVP0Lk1Be2iVS84OSp1fQXBMhWyjm9zktPKI1WewlWqO+IzqB2G+UlFfdyHYQV9uhDfJGL/4Zm2pa6u77qdufea2DkNkWxKFlgp3UlRNWKPdDIp6+siG5uSso2pCVQAwWUC5H9utkQN1ZGgRmKItc/+xEIYoNUytSeKqpLA+1WokeMeyRYyGDgtQJUwV3ZhWb90yc9B2LNSR0V9zDLCappViYmU7ojWdANl6CoYIQuNqpCEwJnJs0MRwudGwvDIqX0SGaEI15NgB+mgtNCjTbbgLVA0JoEWVnoGF9lzrR2jecEaOM2MVqmRYLlC8gvPI2mW13NcdZcVmiXMatGdieixAW69VlVPui793cmEje21xzvKIRBsIi27cAEUgyaXAGsoww/L7g4UqYU3lENr6cK0ZNrXLVoa70ys2C4DcsEaoIRGbCNo8bSkmKDk5b1s0UR3u+DqqNftEQurDiAMG7hJY1XzD0orVWqBNDLSqRRC8ZlJrLUC7mpH7QvaSSnZ1tlLthJdg9i/fB4zta5DKIILapOpzshL2D9wlzvupVlBfgNiYqv8WRMTrhuiC67WqX4uyPPZrCGUWVGcMvm5gHCufyrdmb8xeRLZH3YnEOpLbNnSWLEhHX65CxMpVg1ppka2AJLbxobWpVlpv0FIQ3BoOFTXcljmPrfFL49AJlqYnz+WZzWr9Btuuh7UTVlSav908OeEMuu1SEH+3hluUJVSyQGLIZpRkJlWLykO9kT8o66FOIWUcP5dgXpJ9v9022EMe1vS5Bz05N6lBmGOsg2BGBnlZT3+UKT5QBaCu80WvVXfELCuGrSJiEnkdNuGsI07PAVr8uSFj+gsuKVOF1SB4mVRVLW+BDs9ZhnycZiEWyhAX3KorQA15R+T63vd/TlpbyxMsMFS+wTBDUxAcznDsCyMY5YnRF/c78HWSFQys93/iiG2fbpBu8gv/O5JxUA9PsFQZR0BGvfLhi/V9tDrD+QvNg2d1j2MG0wKs7zg2iW5O63CKKLWEzEa/NP1OqCY33HZ9IPGzE6sEoHEG3Znx1hkizOyeJ+lTXy8x2BbBBOgaFeg+aWMqRsYQpGsjplEPT5w6/YkEJm0yInlJwinoWHNDwDrVPxOXqAVotG4qWKa/emLYbS6mdG9e9MJK+Oa7a6vW8m7j6KiOOsjWmxxbOBQEEDm4tBQA5s9ean5DVFfyyJA2P8Fv2uJOCyBbdlSI1SXhBckK3tPj8wufAZtwf66YdoLds5732+bzABtE4w1wQh9SXr2wP3o9+5X3GqKntA6G1uaS77bWtexAt3bAMEFx/V84EvGVLy8LrUkObbDbw9xXFNbJX1dTB/n06vmil7t/BQTXP/mo61i4OGvu7ipqATZir3g31WqjvicAovlP0r/wcFixa8G7w/rjVNhsAVag59iibEK5rpep/RceCFvs18JRTm9+FiX6iTALjdOIW6rKLUMqa7y4/xFB7HdzaxIcFeoq55x0EtyZP2mRGys/6ZzSKodiwYP3Fx4C8053b2Y9R2J1PTYbagakPv4O1cVnV7LXL1//GRHaVY17eqJf1QU6BkHuzbW8Mr1gZvSFB8GabNBa3X2wAERQ+fRFcvpfMvAsWOPMovu2jmHl+FtA3tpk09AD3bPJfuEEiE1fmwscAIB9h8+emrX7RsmL258bJ5jzOQgW2I3X2OdSWtu+8MQ4TcEpArAbr/m3OK2aVskLz4Qo023OnMpViCvx2kvxfwqsYLm1HKJchdiN6Viw65xoKZ7SlcTyhY2xguXW8oXyvUC7sVqezMm/qyF/bZIQ+ZDYCw8K6t/Y+jVI//Ewfy/Qkc4GO96Sb81mL5foWdErA2uvKacg+Bjd6ng7ITJeFuApEddrggHWXoQfzmYpNr5+ZfrTfNjlc8xSgMRhVAey9g3LOJzkVFjX3ZRD38/EFJ/Xxeu4Jjm3K+4zo6jl93yFu0AehdYV6mfEp7FzOvrmXg9h/csFDX8mdVPUZ0tqNvSZURXztAzPphCk2rYgEiFUlj/SZbh6+PJ0fL/y5hauj/meFZUvybExWI13RmjMp9tqdvH96ebXC+q/CuFlESOdnhsDPJ2Kr+PF9ArWnQQ3hLwSaz8UvwIVUrf53qOHgRoBqwugfoT6M+ORDsJyeCInd/XZNafl3Fv78nKlRvB1Q4C2c4BtyWlB9VT4M8O3Bms/BnU1HDyBnM/OygJVjPvnFgX11iDU7lA9LuTfqWtsn72OrEhe5wanZYHaNu85tjpDxMVjV2CdRXvaI4IaCADGP3uehcP1RUg1LmUBHg4bzoI0jR69J2vUUft8bGiX4Zy3OnJFcoWrEA+KzW9nb3XIlZj+s1kDH6h+sdpLALohP6JT+UxPDRFtjeKF39uznRxrk1lxmfq8u0A6SGOrC4MZgaadsBx6Z+r5XfFqtc8DylW7O/lizcdXwzU/ZnVBkgSUm9vikZPN9GXcOhOouvWPsHIlWih2UgUT3GVVLkpEPw9w2Lo6bHU24kSpH/9v8o66F+CSK7cLdO+uZgE4lA/Fq/g86+oi7rSol7Q6tH4THtHjbyI459bFHmIoftKfD4gh5t51XS9/rO/nVSySO6Zi9c/+rRzlWlitWzGcKiZZIWoBdkV8Dw/fbcn6Ulod7dMnAzmXOyKH2QtJ26OAoOUvXWvL0TZrLM3xgLFPIx+5tM5Yx+4w0yzlaSNP41jxru+WOIwebCgHfVAxyH61oU3LUKk2rvD16Auuas4tXnRyoMzpV70Xv/MmqKOzkuvss4Lxe4HaFwLZq2oPD3AQdhkxkFyrpnfS21RYvbEje9sQjxgKB4FiZYlZj1Ypv4QASKIfYQMljaQ92jA88OEYmQdA0BOPnAQ3ycAlBEDa70tb6wmRUCXnyDyYXsOmcQvT0fNpCH41olaJH2/5i8tyGpcQgEdDh4Sy3NEaVrO2+MJHiNthKwOXlukSAvBom7bq9tVGPTwf+nibq4PqerFhzvpATAF5tEPli0eGGhi09hBp69dZ38Nl0XydrTgH2JycNei4NHN5RDxSv/nKgOFR0EFqEVAFl0k5B7gaLq+IWg8u0TZme8/18mu5Cj3irZwDXEdLPNd2pYsS7O111vsJMK1n7jZH2iSc9dWIE57R50mI78jeqK7rgy78srbdLKT8lcjjcQUeaYHjl7PTKKnvr3lWdUGSdBB/10y5uA0KcOL81au9D6qOxEVDuAy7X4GGbYDKu+bZ5Q3O+gsxjw5V90juPwwPqzQaISzA2pjqilmga4PKTJrD5f5l6by8aQvVm54vPTai+T7FFQXgIn6DAedCqFqNXAbtMOLqJweeAXRIikw5Cb4arm3O1Vs9KXhR706fxaCO4F3RaKz/Q5EduKIAXDgUZtk2LIlVb0E3Zai3hTPi5bh/E0T+0gwT/eGCBzMXJElA3nOGGuFHTuvy67cC7tzqE+cnwRd0uK9GTwbVZc8DYhRRMmRP4RDWfFTuj9C9njeo3InT/1K8giuRJGlRt0B+6oVK7k8vLuzfnYwQb28Yu0yX4U5Rt5EPU10DHcuDtW3aNQB93FsPBKuJ67jzfU4QfGEnZyWkjyQvwx3D/eSScUFY6YQGROSHB8F9I+2Ey2qnETZ5zjmX6h3tRA6KCo+QUJiwDZ2tonMtAbgeppPa2ke21WW3Tnk30cErD4+fFpR+xHChc4BL7dJMDMJimwnK76WpMTJT9pv7Hn3xSY7rXNjeBvOKR2eEp1iA6uWBE6HRpn4jMSu3gNm/bYGae0PKK73CpbSNhEee49XVZrWos4QTBMA5nT8XyWUpKfGu8ZxO8wjpoQmimRBuGIqcAO1SdAQ80USXigHOtQDhAJer85P4SFJlATN2R1F+BbsrUaGKnHcB6LdBT1cIpxMAAEJBSKIUMg+nXMlpajSoN/Ouj0NJXQTg9BU6nQAAiFGxlQx05ChJiCi436t8jUUNYXTQQwTbLlDbfPEifyAj4Urb5a/LEgqnf+FLWjJy7TE7rMJj5fw1qF98q0BUc9a2NiTADTF1+8irtw3YBZ7WawVEWp5h7i1/bXnwZlbHP0/ALA4uc1aAc+q2ekPjbGxAWgcT3+ZFuQT3n03ExCQ5PzViyjMYpjxxoHpJAQSX1nliTHs7pz/zRPDa3A+u+thnPAC42ofiTyQGl98beIWG/Pg3SZ066hAwf4Yw4vtEHYyzEfLjNlSUy0nwFTTE2RZgxDbeYLaeuXOyeJwdHW6CRz/+ddZBfeQzdvUqxDV48lBsxmnqufJsZOTl0MIXyqQwT3qm+lQm+APUGMcBD4dtb9UmAfjsQfCBk7e8/HwLcuHJ+Z41jFON6pCXROgUbKtZXOgk+DLwT1X9CkE0iRvLt67ChbnfSoXtiwvdBj0RpL40K26yQ92ddDTE3F5cVQaSU3cclNug57pBZ+9N2oMr35NpxXRpjx4jERSycltNhDIB+NQBAAAYYdVWPs+ZmEX50Cz7CgRPP0IQiQSOJABXWZnz6HAOp05E9+g833qVDfawnxY2Z3+tvxBzFsTyXIdbVoUKp3zZ4GpsyY+OI4spBrjmIh1MVW24ay6Sj4NpXs+4eDzFl7MAJxKjrb5HzoXDZJTMdKldtnHcoo4LsqRB5XodL4646iLaBgQIqAdViEYN4+hX6XSLGS5ePrEio/acCNp8Yw/dr9bVszh5LDT/UPYx3K8disprvddBJZGiPUJNnhFXXblrDXApo/yqRqvlFuAmmG6DbnC62Akr53h1tC6XiERTarVpvrxyS+MDUkGtVzPMLoLYimn5QZhwhPZbMjaFZZPOkoPguB3koTY7nMpzeVBlA3FphfF0vqHaNld3G56OR2EpD1lYJYtfhWAOE+6bmpWdE9nP9saaEau3DNCYU+6WpPvSuKgewuWAuln05oAK56yoT2MfcNy1I1Ouhlvf41VWvxVidwtg3eG+LBSpjdavL6awAq3Ht4ven/vB1M9uq4pQfmfeAahlLKw6HOdfhrtOZBY2O6OmKZyZ0g8gyC70+3bZSiiJEiLNOUzqHU29mPrZa5E3jEp7Sewn4WQBmLdtmfspzk9ODxgEZHvDP18rspBosb7oaukBi9dEnGvN0GiWvXTXbTNGb709enzqfD0qArDnFcLcMb0Wcl6EUiYlT1AmNt2iy+xAZk8qmftMYjf+2Lszl32uVZ8jPsuH4o3Bj8kCoSy6CDD7OYNkhezFxL2thzVln4iT9x/Klu7KOGS83Tr70zYJUZsvVJCwShC8rUYxwkGT7tONg0gKT8UkCmn6GwKJK1LQGdoII6nXsABY8x434VRLGLKSVeOsI5LyV8G1q3414jYSXgsHaf5XeLqXMghy72d+z/98Rp7OJ9lqb7h52A26UpOP+WKsSIZuuVQ1v2aBeRt0/51LJn4p2fPYYQNkVlYNCpY4fvw9Ju/dKRFt8zVE9YUj8aIK5lVooxQG+zyLrY8cocd0gbadjKE5uCN0PPszZ4b9y4HKu3EKS6I9fxxRQhvOt7pfOcVxhD3SYvtOUmPKOUCVkuPuAok40s+AHwprERx5SAXlNNrcHlzkpwl8hHgP2JityvPCtvAsxYZB25vF+vs306D7J6UW54HmowA24gFnAh1DqOn2+SijmZu7FnnVzqxZVa2hf6qR3pJdfz30OeXSa8H8btDNT1VUzApjWR2ZYDkAQWPkWImDjZhrmngdiQ5arTSGNo50RZJIVEeO7jnxd53rz4PgI7eR2O8247x2XEvSIhTQ0smY+lATgtHeujB1j/ZJZUpIUbnE60ctB6mWxFNDa6jrbLvEABF7sRKSz0751Db6+korR01NjIofDf2TSjYXa1rWDZ376oFQvhlZc0uNEPAH8ZxkHyWpoKf/W2OjVW65lLH83OQwhydXh4yt1ZL8rXhE2bYv88Ppg2CLFFhX2coOF12Uyw9mP2G+IJ2KpcdddlgM2+ErYLlibsUC5T2a+LC2t7UDWi/DbcalaEz0gMA3h8XxKEpkw1Qx14uJ0bG4Eoc+D+EsNtqj9HO57oYK7bzzPDJJItTHVpp3YC2MOsL23LxVkoZ/K8RhHolYTbV8d0j1prqCqhuDCkvxz8ur5h8Bku8+15HMWR0wElBSOltcv7NysmI47/GBCDHRSN5x3w3K0mDZZuImp6FBGiRnWxV8krotVZ5ec1nEe+gapbyHaY2tGN3qKjyuERnxSKB1qeqy0710CSNNrV+P3i/VdpBNKJTb5rB2wB9RFZj0z/fpd0XVz8jXU11Yw46YcmGH8jLaX34irxPEBgIT7KLyofjNofVfbNAeBFhZmsy9jvYDbqvzjP5CQFWPuLrf9K/sbk1Oa1cH2CswysjBLlq/GrF5UhEfI89+7KFE1Qn6utzBwVGKNv4qEuIxhBxYy37qvXXZ86084YaR9/0LMQ5vzY+qrvCWxDAkXb65i7qWSSuDn2NpcHFWFfsA2YauGKJvdn5k50GmQX1+aOMW1fEo0uP8k6wrvzKtQgyVE6Bi55bCvD6ICmVdzPWWlsdMxm2rxF0+9MHIPdU0r4IlZhnozCzu8clmH5XvBRLocYFQvNY4j+UN9oJ6+JADDcHQuoHc3aTJl8OsnzWBoINtF4mlf+N0qnkN1sPpYdEI02e7MRI3/0gKW4y8d9IVxo4I7katGom9Vx05QqWJ7CpOl0PJGrC5VPtkNKuW1aoc6b+pWrWX3ALzAQBgaGS+bQTaSL1tN8C+yPfe8lnVjd8E1JvBYgkIlrWMQ04Zs3LUCmVlj9BItUAPoV6O/l4gg6BCCeUu08MCA9q9T8n1cQZq7K4WyretSTLsyq6xwTqcMe2NpvvzCqd9MRY576yiB0N1/w62do5iUcM0WCGfq6ZGejqhE1ZeYgS7CyRNmNN2LVmP4O08JXJfBQ2OX7k/608qvB5kuNU91gDF/S21oxejPg38BPdhY+1FRFPzsfIgCn269xKG/pOMFxqQJWx3H2iv5J7GglNZ03gpCxTof1NYVKpZtKcGGbzYnZ8J4hj3c+UU1gTQDWi6CrHjqn1KN8vKonZkOa+MjcW410XRn1gCwBPYG26FGpXLs6UXPhs2OqaM9iRjgHP0Dbsr8cJROCYMgK50fh8r0vS/ZWqKBRCD+yd5azO+qZfazbMXtsWuy0ylXW8ipYOwtrA+lzNFAEipuRfUc4eXHXhkJL5vcGlnQannQaouPpq1insGqYZyF6hRw7/Y9QUGzFRyk/9zXBiYxhgcMtv1//oD8BceFGtu141t6jLjs8d4eyI49mIBqpcu7B6ZMVnLvfnR/OuixMNCXrLo62RNywYJ7LgMR6WNy8s3xl5niS9sj7Wpkdl3z21IUyednBI9CNMub6onOBvjZQc+CcqPATRbj/LDxA3pV08AAjLFxnjp6xfWoJLNdJ61uv4Jg/3I6US1d83KWkxX78HSBi/j8EyI7KaTA9UesU8aohruDnPPFREyxibxswGCsxs+iPTi/mdE966yP7PSMEQ0CD7DuVHGfB2WPSU0PYh2Bb95GwZHdLa6dbQJ0hnJAZfZXzgJ5sbGdjxyDsCfD/nRcTs7r2fFly/zeWHnehrUf1nF/COF1p+w974WJTdFtat9VmRcRfSuyEvrPx/EnpostoWa1APdoW9UNFvtotFfB8PPh5pKjXcS6sDK86jnACguJQWDkHyUXT4H9BKDJ4C7ifzhih2vdDW+9Q7CqmPPQpy+A3BHPwWFEL+uDD0o4lnzwO23pv0vToutD8Wr8DKSWFgA+XxLHDfSC5ui5isn1jzsiyQmBM8BQlE5AlChpmXWXjZlj+rzd06FnYH7UD8gfDTgSZ9BpWwxxQ5ybXZM/n3swhQAFD8h47Pe8fMVkK9hJcv5NyYssv2xq7yS7yvz0JxWDp0quu+6+HVYfSkbKxmNqSStt9YRejhE3frxb4zf2EfF7PHQr8ksWrn91Rbq6yo31tci4ilZYztp3eq+qq5aH46M/jd3LLGiLiO5R+lTBK2EO8yM5e8DrPxAVtYcoWBMawednbXY2vsEgtodszLqUyjrpLcRPUN2TRLjqm6A1QMrDIaMwU2kWGV1xVTKDw7D4jP1O4HcAmSF/d3JLmj57ahUq7MGY6duFWNrEiWyk6abfUkjqb6dZA4pFczCov1I1pH6ILh/tkOiVLNKujuxiiOGolHkLNdj2WT5vUCLQSh+x6Bxf3ChNkbATyvgiERkCMzeolaYykmwL5ZrxKQISmFmP9m/R0A3pQQNQl7UKje3cKgWFwjLx1u5gONvx1Uosd6y6WSQVqKqUlXtWe5TK7lSHqo1H4St4+je4o5W01iOKbtldUk7VQ46wxWoilDjP8dNWiUb8m+fSYXqNufDn2PXTgCbup9EdZzzyvaJZ3rluI9neR+sh5t4VDic2ypky6A3toZsoTcQzrzTarVPwPEMMiZy1lxWzvfIXD+tlNSH8Z0qRyxcTsyqOAdhFBmpA35g4sJxVVBUWEtSR50XBFamWAoXIP53k0oL4FEVOQlmAds2SpcNAO7yUFFRNlW73J7OFzLI0IlEeZPPo/aZl2LVAiMAAWKpu0mxPMtAQ15W1tBdvVMUYSDLQbaZfunuzSFVlpWp0hsbzyzfqclhJtGX1+4GAGHLspCBQLR8+mbJT+2tda3wKpxT0ZoW2gtqk7BCwHhYEq8jyTZZJNZbUxRDYmXUnvXy9uiueE9zJyFrMlXmsWvEYix/JbK67YzdAYCAKNO9p/gcKwYl7QUVc+VQ80AypM8fgegtQnJaUed8Qk3C+L2lPtmOya5IzKINknR72vwJI4vchQAAxu254Nqtpsq39E7bK3X/RguASUbzzTheBpgZBoM1SviqQX2K2iMrnvcfpVFSIbP1mD7Aqn3hK9r80LT8UgaqU9OhRgPxcM13llqoWJCyQDOj6qAkAJQRjdkGwDQfGj89LFaIyv3mThscKxJMA9ekX1/eef/UhUMU3WIpeKICzjzNHuU0LpTMC7ooszjlWaeSxrw/zB9QfbkqjCidpDX+T05YGIsTDsWOWM7VAp4FUn1MVLSUeJk1yc1CUmkHRAiSmDUNjeZKMk66PuE+VcPST/ncvOD47IhTkQRGfNdwaxvlZHXRw/VtUgTV+JYJAKr+DJUywLYrezvqOUpik5bZt0TlaBfHSuY8qM/Z4HDbkpfLUcngxU2h9MocibzQJsPTqGRUI9HhLZcSKoP3oCyKJojaSD5GAUsbk03gUQTjWpD8xEqqMcoO3G+ybIvN9JqO/EOLp4ry64DYHG+AOLIsLgI39gLA00mTYYksjawTaIX59pB8aZa80IZ9VHtkNFRLu1DwQGa4Rq8PUznz6u9LwhQI4IbwcYfhAwaaU/sDwFDyOrtJZqt2Fg7l8dUSaJW2uCCuZvWUmT8jZFTZ9LSPtVgwvKGMqPu1Vf9K5ziJQe7O5N57+iZdRIDb+/udvsAdvhO8Idw/AL8CEsAbAMBAcC850te+ij0qy5OvpdBO9iNjtmzDOgVjN2e3G3GXU3naO53J7Mfv20Q6dGAlqvrGSnnPvOd0CYK4eqX08/brrz/Dn/8F4ctXQIJhALjP/94A3mAAuLM/vyGvFutE9c6mlr60CqtBXUOnmzpah/psazhpNe/n2tKUxPZRVE2nuhGMBnNcnKPN4ZeP39/gh6/w5Q2Gb3O9DwACuAMQwDD/JUnXlqUgYhvIvtS3LRESin+BDszx463Y+kga1CZWBfmoe0abgJHBSrwGgYepw5zxSKvAe8iIQBQ0paQwAtw+/vj7dwAEuAMAwABwm19/zMmcAUbDoBJP4oUzvajpkPVyz5QqnlIQ6vqmEtVByxuy8riaV/VWsEN/CMYKG5od1AjLdXDU83LrSYIV5o4YmfnRdADFSnOWvX376U8Ew/f5Xuj7PJsPgA+Ar3PtAYb7IgOjF7Sk+dNvlaKMXhScpAfD/lpureVURdvRNtjKVNVpt7Q8fQdJlSbZZimikucDkfcmVGonkHGz4b5Y3oA1br56OCf6CQBwFonkAt3+9OOPvwD8CvAV4PsU+wIB/AjwDgAA78V3qKeoYhk7FwPLqOfTFRgPidPDp0/ymEgBaK5T9hxOeymKdiNi1HTIBMwfLjdDbq3pTJZVnJvjHQjGIPj3t69/ALwD/DH5/fdvMADAd4AbwDvAlzkgnnX1EkxjSQ8Zmts2+ZT92oPv8+jfrwOi2ueVwy50iMmYsZnyNrMM1Dk8eyC3LJksmjeeebM0WwCYYoCfhm8E8DPQV0AE+A2GfwDcAb4AfAf4NtsB5AOPfbGSumMbRs6XOZjXVPV0I1YlHsY8J+QSg7a+NaUbhN4NU4LVIMmqgJOfw+ompkmaejokHv7nr//1O8Bv8NuvAB8AAO/fAWD+vpQ3Zf64DMTGjUw0itwC5/8kAaiOXuscWxqu9HSvDnU51qyv0XFDN1SrrXH/cgcnq4bpEQAMQAiIMADgKCe3f/3hT78D/AQ/0eTt3Gjm+3s2GIXUZOFLGPX207J8uVjawlpNEk+0+osMVOfvQPpbfiIj69Nf1VSz6E/7OOxUM5K1sZIytk+7ldbTR6+OJppMEpHMC8F9vvA23OF+A7h9G375PwACGgC/A3yZ4+BvAH8AfJ3dIWEYdQrswuDTbsjkErdQqV7cbbKmF+Eep3PUV5DTVUpIfUCaq2UOYpGpYzqM7WU58hTjseY58YEVONdmImSuDk4q/20mi0Y+v3254z8D/AT48zy9rwA/AgDAN4CP5U9pf5RcBOktOivpwWPCNRym9swcOaNzZyIFRWiUq8BM6gKjm81Dg3EKiwvx4SVdBERVIvlbSVLmoys5eI1StY9WnzYPDRZXaNYdxXEYAbwBfgACwA3ePu4A3wD+AvAd4AeAN4APgC8Av8/WAAFGWSAA9mGL7Bpz6JZoDJStisWO2a8ZKBhd1omMvlF58Zi0QiinxxidtPrt2Mvn9H3i0RzZn01j3cg+LLLlCmHprQMsiXUYH82mbAD4IKA3GN6n4124/UY/fQD8Av/7/f3f4AZ/B/oGCPDHz/D1O/z6T/DtBvcv8EaA75NZZEfC91zgAv7FtZCdkpxrr9dD1QYrPbX1deLwIjHfxyCuR9Kfpxh1NM3uECDgHd4+4GMAeiccEG4//+d//+U//v3Pv3zHb3/72+9/fPvhh/sbfvz6Tl//8eXXjx//9NP7HwBfBrrDgISICESjIMAcV2QTCHqrp4P4u/z7kQr9y4wepVqFmpmL5hcHw+JEy6SwkCB5SjGOJpjOU+t/YK5BRNjyLbTOHDfFJSPjpZQAldWn0GXetjsQIiAMBDQQwv2dAD8A8OPtDd7ubx/w/1lcuBBroJa/AAAAAElFTkSuQmCC\n",
|
541 |
-
"text/plain": [
|
542 |
-
"<PIL.Image.Image image mode=RGB size=256x256 at 0x7FA20677A400>"
|
543 |
-
]
|
544 |
-
},
|
545 |
-
"execution_count": 22,
|
546 |
-
"metadata": {},
|
547 |
-
"output_type": "execute_result"
|
548 |
-
}
|
549 |
-
],
|
550 |
-
"source": [
|
551 |
-
"custom_to_pil(np.asarray(get_images(jnp.expand_dims(greedy_output[0][0],0), model)[0]))"
|
552 |
-
]
|
553 |
-
}
|
554 |
-
],
|
555 |
-
"metadata": {
|
556 |
-
"accelerator": "TPU",
|
557 |
-
"colab": {
|
558 |
-
"collapsed_sections": [],
|
559 |
-
"machine_shape": "hm",
|
560 |
-
"name": "CustomBARTv4b-model-generate.ipynb",
|
561 |
-
"provenance": []
|
562 |
-
},
|
563 |
-
"kernelspec": {
|
564 |
-
"display_name": "Python 3",
|
565 |
-
"language": "python",
|
566 |
-
"name": "python3"
|
567 |
-
},
|
568 |
-
"language_info": {
|
569 |
-
"codemirror_mode": {
|
570 |
-
"name": "ipython",
|
571 |
-
"version": 3
|
572 |
-
},
|
573 |
-
"file_extension": ".py",
|
574 |
-
"mimetype": "text/x-python",
|
575 |
-
"name": "python",
|
576 |
-
"nbconvert_exporter": "python",
|
577 |
-
"pygments_lexer": "ipython3",
|
578 |
-
"version": "3.8.8"
|
579 |
-
}
|
580 |
-
},
|
581 |
-
"nbformat": 4,
|
582 |
-
"nbformat_minor": 1
|
583 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
encoding/vqgan-jax-encoding-with-captions.ipynb
DELETED
@@ -1,363 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "markdown",
|
5 |
-
"id": "d0b72877",
|
6 |
-
"metadata": {},
|
7 |
-
"source": [
|
8 |
-
"# vqgan-jax-encoding-with-captions"
|
9 |
-
]
|
10 |
-
},
|
11 |
-
{
|
12 |
-
"cell_type": "markdown",
|
13 |
-
"id": "875c82b3",
|
14 |
-
"metadata": {},
|
15 |
-
"source": [
|
16 |
-
"Notebook based on [vqgan-jax-reconstruction](https://colab.research.google.com/drive/1mdXXsMbV6K_LTvCh3IImRsFIWcKU5m1w?usp=sharing) by @surajpatil.\n",
|
17 |
-
"\n",
|
18 |
-
"We process a `tsv` file with `image_file` and `caption` fields, and add a `vqgan_indices` column with indices extracted from a VQGAN-JAX model."
|
19 |
-
]
|
20 |
-
},
|
21 |
-
{
|
22 |
-
"cell_type": "code",
|
23 |
-
"execution_count": 1,
|
24 |
-
"id": "3b59489e",
|
25 |
-
"metadata": {},
|
26 |
-
"outputs": [],
|
27 |
-
"source": [
|
28 |
-
"import io\n",
|
29 |
-
"\n",
|
30 |
-
"import requests\n",
|
31 |
-
"from PIL import Image\n",
|
32 |
-
"import numpy as np\n",
|
33 |
-
"from tqdm import tqdm\n",
|
34 |
-
"\n",
|
35 |
-
"import torch\n",
|
36 |
-
"import torchvision.transforms as T\n",
|
37 |
-
"import torchvision.transforms.functional as TF\n",
|
38 |
-
"from torchvision.transforms import InterpolationMode\n",
|
39 |
-
"from torch.utils.data import Dataset, DataLoader\n",
|
40 |
-
"\n",
|
41 |
-
"import jax\n",
|
42 |
-
"from jax import pmap"
|
43 |
-
]
|
44 |
-
},
|
45 |
-
{
|
46 |
-
"cell_type": "markdown",
|
47 |
-
"id": "511c3b9e",
|
48 |
-
"metadata": {},
|
49 |
-
"source": [
|
50 |
-
"## VQGAN-JAX model"
|
51 |
-
]
|
52 |
-
},
|
53 |
-
{
|
54 |
-
"cell_type": "markdown",
|
55 |
-
"id": "bb408f6c",
|
56 |
-
"metadata": {},
|
57 |
-
"source": [
|
58 |
-
"`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities."
|
59 |
-
]
|
60 |
-
},
|
61 |
-
{
|
62 |
-
"cell_type": "code",
|
63 |
-
"execution_count": 2,
|
64 |
-
"id": "2ca50dc7",
|
65 |
-
"metadata": {},
|
66 |
-
"outputs": [],
|
67 |
-
"source": [
|
68 |
-
"from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel"
|
69 |
-
]
|
70 |
-
},
|
71 |
-
{
|
72 |
-
"cell_type": "markdown",
|
73 |
-
"id": "7b60da9a",
|
74 |
-
"metadata": {},
|
75 |
-
"source": [
|
76 |
-
"We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model."
|
77 |
-
]
|
78 |
-
},
|
79 |
-
{
|
80 |
-
"cell_type": "code",
|
81 |
-
"execution_count": 3,
|
82 |
-
"id": "29ce8b15",
|
83 |
-
"metadata": {},
|
84 |
-
"outputs": [
|
85 |
-
{
|
86 |
-
"data": {
|
87 |
-
"application/vnd.jupyter.widget-view+json": {
|
88 |
-
"model_id": "db406bdfc5d5428eaeae1631a04989dd",
|
89 |
-
"version_major": 2,
|
90 |
-
"version_minor": 0
|
91 |
-
},
|
92 |
-
"text/plain": [
|
93 |
-
"Downloading: 0%| | 0.00/433 [00:00<?, ?B/s]"
|
94 |
-
]
|
95 |
-
},
|
96 |
-
"metadata": {},
|
97 |
-
"output_type": "display_data"
|
98 |
-
},
|
99 |
-
{
|
100 |
-
"data": {
|
101 |
-
"application/vnd.jupyter.widget-view+json": {
|
102 |
-
"model_id": "3e37f07fba6d48fca70313ae1fa8cc32",
|
103 |
-
"version_major": 2,
|
104 |
-
"version_minor": 0
|
105 |
-
},
|
106 |
-
"text/plain": [
|
107 |
-
"Downloading: 0%| | 0.00/304M [00:00<?, ?B/s]"
|
108 |
-
]
|
109 |
-
},
|
110 |
-
"metadata": {},
|
111 |
-
"output_type": "display_data"
|
112 |
-
},
|
113 |
-
{
|
114 |
-
"name": "stderr",
|
115 |
-
"output_type": "stream",
|
116 |
-
"text": [
|
117 |
-
"INFO:absl:Starting the local TPU driver.\n",
|
118 |
-
"INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
|
119 |
-
"INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host TPU\n"
|
120 |
-
]
|
121 |
-
},
|
122 |
-
{
|
123 |
-
"name": "stdout",
|
124 |
-
"output_type": "stream",
|
125 |
-
"text": [
|
126 |
-
"Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
|
127 |
-
]
|
128 |
-
}
|
129 |
-
],
|
130 |
-
"source": [
|
131 |
-
"model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
|
132 |
-
]
|
133 |
-
},
|
134 |
-
{
|
135 |
-
"cell_type": "markdown",
|
136 |
-
"id": "c7c4c1e6",
|
137 |
-
"metadata": {},
|
138 |
-
"source": [
|
139 |
-
"## Dataset"
|
140 |
-
]
|
141 |
-
},
|
142 |
-
{
|
143 |
-
"cell_type": "markdown",
|
144 |
-
"id": "7014a7ce",
|
145 |
-
"metadata": {},
|
146 |
-
"source": [
|
147 |
-
"We use Luke Melas-Kyriazi's `dataset.py` which reads image paths and captions from a tsv file that contains both. We only need the images for encoding."
|
148 |
-
]
|
149 |
-
},
|
150 |
-
{
|
151 |
-
"cell_type": "code",
|
152 |
-
"execution_count": 4,
|
153 |
-
"id": "85832702",
|
154 |
-
"metadata": {},
|
155 |
-
"outputs": [],
|
156 |
-
"source": [
|
157 |
-
"from dalle_mini.dataset import *"
|
158 |
-
]
|
159 |
-
},
|
160 |
-
{
|
161 |
-
"cell_type": "code",
|
162 |
-
"execution_count": 5,
|
163 |
-
"id": "81b19eca",
|
164 |
-
"metadata": {},
|
165 |
-
"outputs": [],
|
166 |
-
"source": [
|
167 |
-
"cc12m_images = '/data/CC12M/images'\n",
|
168 |
-
"cc12m_list = '/data/CC12M/images-list-clean.tsv'\n",
|
169 |
-
"# cc12m_list = '/data/CC12M/images-10000.tsv'\n",
|
170 |
-
"cc12m_output = '/data/CC12M/images-encoded.tsv'"
|
171 |
-
]
|
172 |
-
},
|
173 |
-
{
|
174 |
-
"cell_type": "code",
|
175 |
-
"execution_count": 6,
|
176 |
-
"id": "fecc9a00",
|
177 |
-
"metadata": {},
|
178 |
-
"outputs": [],
|
179 |
-
"source": [
|
180 |
-
"image_size = 256\n",
|
181 |
-
"def image_transform(image):\n",
|
182 |
-
" s = min(image.size)\n",
|
183 |
-
" r = image_size / s\n",
|
184 |
-
" s = (round(r * image.size[1]), round(r * image.size[0]))\n",
|
185 |
-
" image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n",
|
186 |
-
" image = TF.center_crop(image, output_size = 2 * [image_size])\n",
|
187 |
-
" image = torch.unsqueeze(T.ToTensor()(image), 0)\n",
|
188 |
-
" image = image.permute(0, 2, 3, 1).numpy()\n",
|
189 |
-
" return image"
|
190 |
-
]
|
191 |
-
},
|
192 |
-
{
|
193 |
-
"cell_type": "code",
|
194 |
-
"execution_count": 7,
|
195 |
-
"id": "4ce2211f",
|
196 |
-
"metadata": {},
|
197 |
-
"outputs": [],
|
198 |
-
"source": [
|
199 |
-
"dataset = CaptionDataset(\n",
|
200 |
-
" images_root=cc12m_images,\n",
|
201 |
-
" captions_path=cc12m_list,\n",
|
202 |
-
" image_transform=image_transform,\n",
|
203 |
-
" image_transform_type='torchvision',\n",
|
204 |
-
" include_captions=False\n",
|
205 |
-
")"
|
206 |
-
]
|
207 |
-
},
|
208 |
-
{
|
209 |
-
"cell_type": "code",
|
210 |
-
"execution_count": 8,
|
211 |
-
"id": "cc922704",
|
212 |
-
"metadata": {},
|
213 |
-
"outputs": [
|
214 |
-
{
|
215 |
-
"data": {
|
216 |
-
"text/plain": [
|
217 |
-
"8592141"
|
218 |
-
]
|
219 |
-
},
|
220 |
-
"execution_count": 8,
|
221 |
-
"metadata": {},
|
222 |
-
"output_type": "execute_result"
|
223 |
-
}
|
224 |
-
],
|
225 |
-
"source": [
|
226 |
-
"len(dataset)"
|
227 |
-
]
|
228 |
-
},
|
229 |
-
{
|
230 |
-
"cell_type": "markdown",
|
231 |
-
"id": "62ad01c3",
|
232 |
-
"metadata": {},
|
233 |
-
"source": [
|
234 |
-
"## Encoding"
|
235 |
-
]
|
236 |
-
},
|
237 |
-
{
|
238 |
-
"cell_type": "code",
|
239 |
-
"execution_count": 9,
|
240 |
-
"id": "88f36d0b",
|
241 |
-
"metadata": {},
|
242 |
-
"outputs": [],
|
243 |
-
"source": [
|
244 |
-
"def encode(model, batch):\n",
|
245 |
-
"# print(\"jitting encode function\")\n",
|
246 |
-
" _, indices = model.encode(batch)\n",
|
247 |
-
" return indices"
|
248 |
-
]
|
249 |
-
},
|
250 |
-
{
|
251 |
-
"cell_type": "code",
|
252 |
-
"execution_count": 10,
|
253 |
-
"id": "1f35f0cb",
|
254 |
-
"metadata": {},
|
255 |
-
"outputs": [],
|
256 |
-
"source": [
|
257 |
-
"def superbatch_generator(dataloader, num_tpus):\n",
|
258 |
-
" iter_loader = iter(dataloader)\n",
|
259 |
-
" for batch in iter_loader:\n",
|
260 |
-
" superbatch = [batch.squeeze(1)]\n",
|
261 |
-
" try:\n",
|
262 |
-
" for b in range(num_tpus-1):\n",
|
263 |
-
" batch = next(iter_loader)\n",
|
264 |
-
" if batch is None:\n",
|
265 |
-
" break\n",
|
266 |
-
" # Skip incomplete last batch\n",
|
267 |
-
" if batch.shape[0] == dataloader.batch_size:\n",
|
268 |
-
" superbatch.append(batch.squeeze(1))\n",
|
269 |
-
" except StopIteration:\n",
|
270 |
-
" pass\n",
|
271 |
-
" superbatch = torch.stack(superbatch, axis=0)\n",
|
272 |
-
" yield superbatch"
|
273 |
-
]
|
274 |
-
},
|
275 |
-
{
|
276 |
-
"cell_type": "code",
|
277 |
-
"execution_count": 11,
|
278 |
-
"id": "2210705b",
|
279 |
-
"metadata": {},
|
280 |
-
"outputs": [],
|
281 |
-
"source": [
|
282 |
-
"import os\n",
|
283 |
-
"\n",
|
284 |
-
"def encode_captioned_dataset(dataset, output_tsv, batch_size=32, num_workers=16):\n",
|
285 |
-
" if os.path.isfile(output_tsv):\n",
|
286 |
-
" print(f\"Destination file {output_tsv} already exists, please move away.\")\n",
|
287 |
-
" return\n",
|
288 |
-
" \n",
|
289 |
-
" num_tpus = 8 \n",
|
290 |
-
" dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)\n",
|
291 |
-
" superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)\n",
|
292 |
-
" \n",
|
293 |
-
" p_encoder = pmap(lambda batch: encode(model, batch))\n",
|
294 |
-
"\n",
|
295 |
-
" # We save each superbatch to avoid reallocation of buffers as we process them.\n",
|
296 |
-
" # We keep the file open to prevent excessive file seeks.\n",
|
297 |
-
" with open(output_tsv, \"w\") as file:\n",
|
298 |
-
" iterations = len(dataset) // (batch_size * num_tpus)\n",
|
299 |
-
" for n in tqdm(range(iterations)):\n",
|
300 |
-
" superbatch = next(superbatches)\n",
|
301 |
-
" encoded = p_encoder(superbatch.numpy())\n",
|
302 |
-
" encoded = encoded.reshape(-1, encoded.shape[-1])\n",
|
303 |
-
"\n",
|
304 |
-
" # Extract fields from the dataset internal `captions` property, and save to disk\n",
|
305 |
-
" start_index = n * batch_size * num_tpus\n",
|
306 |
-
" end_index = (n+1) * batch_size * num_tpus\n",
|
307 |
-
" paths = dataset.captions[\"image_file\"][start_index:end_index].values\n",
|
308 |
-
" captions = dataset.captions[\"caption\"][start_index:end_index].values\n",
|
309 |
-
" encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
|
310 |
-
" batch_df = pd.DataFrame.from_dict({\"image_file\": paths, \"caption\": captions, \"encoding\": encoded_as_string})\n",
|
311 |
-
" batch_df.to_csv(file, sep='\\t', header=(n==0), index=None)\n",
|
312 |
-
" "
|
313 |
-
]
|
314 |
-
},
|
315 |
-
{
|
316 |
-
"cell_type": "code",
|
317 |
-
"execution_count": null,
|
318 |
-
"id": "7704863d",
|
319 |
-
"metadata": {},
|
320 |
-
"outputs": [
|
321 |
-
{
|
322 |
-
"name": "stderr",
|
323 |
-
"output_type": "stream",
|
324 |
-
"text": [
|
325 |
-
" 4%|██▋ | 621/16781 [07:09<3:02:46, 1.47it/s]"
|
326 |
-
]
|
327 |
-
}
|
328 |
-
],
|
329 |
-
"source": [
|
330 |
-
"encode_captioned_dataset(dataset, cc12m_output, batch_size=64, num_workers=16)"
|
331 |
-
]
|
332 |
-
},
|
333 |
-
{
|
334 |
-
"cell_type": "markdown",
|
335 |
-
"id": "8953dd84",
|
336 |
-
"metadata": {},
|
337 |
-
"source": [
|
338 |
-
"----"
|
339 |
-
]
|
340 |
-
}
|
341 |
-
],
|
342 |
-
"metadata": {
|
343 |
-
"kernelspec": {
|
344 |
-
"display_name": "Python 3 (ipykernel)",
|
345 |
-
"language": "python",
|
346 |
-
"name": "python3"
|
347 |
-
},
|
348 |
-
"language_info": {
|
349 |
-
"codemirror_mode": {
|
350 |
-
"name": "ipython",
|
351 |
-
"version": 3
|
352 |
-
},
|
353 |
-
"file_extension": ".py",
|
354 |
-
"mimetype": "text/x-python",
|
355 |
-
"name": "python",
|
356 |
-
"nbconvert_exporter": "python",
|
357 |
-
"pygments_lexer": "ipython3",
|
358 |
-
"version": "3.8.10"
|
359 |
-
}
|
360 |
-
},
|
361 |
-
"nbformat": 4,
|
362 |
-
"nbformat_minor": 5
|
363 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
encoding/vqgan-jax-encoding-yfcc100m.ipynb
DELETED
@@ -1,1136 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "markdown",
|
5 |
-
"id": "d0b72877",
|
6 |
-
"metadata": {},
|
7 |
-
"source": [
|
8 |
-
"# vqgan-jax-encoding-yfcc100m"
|
9 |
-
]
|
10 |
-
},
|
11 |
-
{
|
12 |
-
"cell_type": "markdown",
|
13 |
-
"id": "ba7b31e6",
|
14 |
-
"metadata": {},
|
15 |
-
"source": [
|
16 |
-
"Same as `vqgan-jax-encoding-with-captions`, but for YFCC100M.\n",
|
17 |
-
"\n",
|
18 |
-
"This dataset was prepared by @borisdayma in Json lines format."
|
19 |
-
]
|
20 |
-
},
|
21 |
-
{
|
22 |
-
"cell_type": "code",
|
23 |
-
"execution_count": 92,
|
24 |
-
"id": "3b59489e",
|
25 |
-
"metadata": {},
|
26 |
-
"outputs": [],
|
27 |
-
"source": [
|
28 |
-
"import io\n",
|
29 |
-
"\n",
|
30 |
-
"import requests\n",
|
31 |
-
"from PIL import Image\n",
|
32 |
-
"import numpy as np\n",
|
33 |
-
"from tqdm import tqdm\n",
|
34 |
-
"\n",
|
35 |
-
"import torch\n",
|
36 |
-
"import torchvision.transforms as T\n",
|
37 |
-
"import torchvision.transforms.functional as TF\n",
|
38 |
-
"from torchvision.transforms import InterpolationMode\n",
|
39 |
-
"from torch.utils.data import Dataset, DataLoader\n",
|
40 |
-
"from torchvision.datasets.folder import default_loader\n",
|
41 |
-
"import os\n",
|
42 |
-
"\n",
|
43 |
-
"import jax\n",
|
44 |
-
"from jax import pmap"
|
45 |
-
]
|
46 |
-
},
|
47 |
-
{
|
48 |
-
"cell_type": "markdown",
|
49 |
-
"id": "511c3b9e",
|
50 |
-
"metadata": {},
|
51 |
-
"source": [
|
52 |
-
"## VQGAN-JAX model"
|
53 |
-
]
|
54 |
-
},
|
55 |
-
{
|
56 |
-
"cell_type": "markdown",
|
57 |
-
"id": "bb408f6c",
|
58 |
-
"metadata": {},
|
59 |
-
"source": [
|
60 |
-
"`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities."
|
61 |
-
]
|
62 |
-
},
|
63 |
-
{
|
64 |
-
"cell_type": "code",
|
65 |
-
"execution_count": 93,
|
66 |
-
"id": "2ca50dc7",
|
67 |
-
"metadata": {},
|
68 |
-
"outputs": [],
|
69 |
-
"source": [
|
70 |
-
"from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel"
|
71 |
-
]
|
72 |
-
},
|
73 |
-
{
|
74 |
-
"cell_type": "markdown",
|
75 |
-
"id": "7b60da9a",
|
76 |
-
"metadata": {},
|
77 |
-
"source": [
|
78 |
-
"We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model."
|
79 |
-
]
|
80 |
-
},
|
81 |
-
{
|
82 |
-
"cell_type": "code",
|
83 |
-
"execution_count": 167,
|
84 |
-
"id": "29ce8b15",
|
85 |
-
"metadata": {},
|
86 |
-
"outputs": [
|
87 |
-
{
|
88 |
-
"name": "stdout",
|
89 |
-
"output_type": "stream",
|
90 |
-
"text": [
|
91 |
-
"Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
|
92 |
-
]
|
93 |
-
}
|
94 |
-
],
|
95 |
-
"source": [
|
96 |
-
"model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
|
97 |
-
]
|
98 |
-
},
|
99 |
-
{
|
100 |
-
"cell_type": "markdown",
|
101 |
-
"id": "c7c4c1e6",
|
102 |
-
"metadata": {},
|
103 |
-
"source": [
|
104 |
-
"## Dataset"
|
105 |
-
]
|
106 |
-
},
|
107 |
-
{
|
108 |
-
"cell_type": "code",
|
109 |
-
"execution_count": 94,
|
110 |
-
"id": "33861477",
|
111 |
-
"metadata": {},
|
112 |
-
"outputs": [],
|
113 |
-
"source": [
|
114 |
-
"import pandas as pd\n",
|
115 |
-
"from pathlib import Path"
|
116 |
-
]
|
117 |
-
},
|
118 |
-
{
|
119 |
-
"cell_type": "code",
|
120 |
-
"execution_count": 134,
|
121 |
-
"id": "81b19eca",
|
122 |
-
"metadata": {},
|
123 |
-
"outputs": [],
|
124 |
-
"source": [
|
125 |
-
"yfcc100m = Path('/home/khali/TPU-Test/YFCC100M_OpenAI_subset')\n",
|
126 |
-
"# Images are 'sharded' from the following directory\n",
|
127 |
-
"yfcc100m_images = yfcc100m/'data'/'data'/'images'\n",
|
128 |
-
"yfcc100m_metadata = yfcc100m/'metadata_YFCC100M.jsonl'\n",
|
129 |
-
"yfcc100m_output = yfcc100m/'metadata_encoded.tsv'"
|
130 |
-
]
|
131 |
-
},
|
132 |
-
{
|
133 |
-
"cell_type": "markdown",
|
134 |
-
"id": "1c58bb4a",
|
135 |
-
"metadata": {},
|
136 |
-
"source": [
|
137 |
-
"### Cleanup"
|
138 |
-
]
|
139 |
-
},
|
140 |
-
{
|
141 |
-
"cell_type": "markdown",
|
142 |
-
"id": "1a14ae3d",
|
143 |
-
"metadata": {},
|
144 |
-
"source": [
|
145 |
-
"We need to select entries with images that exist. Otherwise we can't build batches because `Dataloader` does not support `None` in batches. We use Huggingface Datasets, I understand they support threaded reading of jsonl files, and I was running out of memory when using pandas."
|
146 |
-
]
|
147 |
-
},
|
148 |
-
{
|
149 |
-
"cell_type": "code",
|
150 |
-
"execution_count": 96,
|
151 |
-
"id": "7811648c",
|
152 |
-
"metadata": {},
|
153 |
-
"outputs": [],
|
154 |
-
"source": [
|
155 |
-
"import datasets\n",
|
156 |
-
"from datasets import Dataset, load_dataset"
|
157 |
-
]
|
158 |
-
},
|
159 |
-
{
|
160 |
-
"cell_type": "code",
|
161 |
-
"execution_count": 10,
|
162 |
-
"id": "4811a230",
|
163 |
-
"metadata": {},
|
164 |
-
"outputs": [
|
165 |
-
{
|
166 |
-
"name": "stderr",
|
167 |
-
"output_type": "stream",
|
168 |
-
"text": [
|
169 |
-
"tcmalloc: large alloc 1254047744 bytes == 0xb2b08000 @ 0x7f9e78632680 0x7f9e78653824 0x585b92 0x504d56 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56 0x56acb6 0x5f5956 0x56aadf 0x5f5956 0x56acb6 0x568d9a 0x5f5b33 0x50b7f8 0x5f2702 0x56c332\n",
|
170 |
-
"tcmalloc: large alloc 1254047744 bytes == 0xfd74e000 @ 0x7f9e78632680 0x7f9e78653824 0x590214 0x586f90 0x56e1f3 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56 0x56acb6 0x5f5956 0x56aadf 0x5f5956 0x56acb6 0x568d9a 0x5f5b33 0x50b7f8 0x5f2702 0x56c332\n",
|
171 |
-
"tcmalloc: large alloc 5016190976 bytes == 0x148b42000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
172 |
-
"tcmalloc: large alloc 5019099136 bytes == 0x273f12000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
173 |
-
"tcmalloc: large alloc 5019811840 bytes == 0x39f9a8000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
174 |
-
"tcmalloc: large alloc 5024571392 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
175 |
-
"tcmalloc: large alloc 5021097984 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
176 |
-
"tcmalloc: large alloc 5022818304 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
177 |
-
"tcmalloc: large alloc 5020794880 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
178 |
-
"tcmalloc: large alloc 5019451392 bytes == 0x39f9a8000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
179 |
-
"tcmalloc: large alloc 5020565504 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
180 |
-
"tcmalloc: large alloc 5012561920 bytes == 0x273f12000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
181 |
-
"tcmalloc: large alloc 5021835264 bytes == 0x5f6cba000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
182 |
-
"tcmalloc: large alloc 5017436160 bytes == 0x273f12000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n"
|
183 |
-
]
|
184 |
-
}
|
185 |
-
],
|
186 |
-
"source": [
|
187 |
-
"# The metadata is too bog to load into memory at once, so chopping it into chunks\n",
|
188 |
-
"chunk_size=1000000\n",
|
189 |
-
"batch_no=1\n",
|
190 |
-
"for chunk in pd.read_json(yfcc100m_metadata, orient=\"records\", lines=True,chunksize=chunk_size):\n",
|
191 |
-
" chunk.to_csv('./chunks/chunk'+str(batch_no)+'.tsv', sep=\"\\t\", index=False)\n",
|
192 |
-
" batch_no+=1"
|
193 |
-
]
|
194 |
-
},
|
195 |
-
{
|
196 |
-
"cell_type": "code",
|
197 |
-
"execution_count": 25,
|
198 |
-
"id": "46b2f083",
|
199 |
-
"metadata": {},
|
200 |
-
"outputs": [
|
201 |
-
{
|
202 |
-
"data": {
|
203 |
-
"text/html": [
|
204 |
-
"<div>\n",
|
205 |
-
"<style scoped>\n",
|
206 |
-
" .dataframe tbody tr th:only-of-type {\n",
|
207 |
-
" vertical-align: middle;\n",
|
208 |
-
" }\n",
|
209 |
-
"\n",
|
210 |
-
" .dataframe tbody tr th {\n",
|
211 |
-
" vertical-align: top;\n",
|
212 |
-
" }\n",
|
213 |
-
"\n",
|
214 |
-
" .dataframe thead th {\n",
|
215 |
-
" text-align: right;\n",
|
216 |
-
" }\n",
|
217 |
-
"</style>\n",
|
218 |
-
"<table border=\"1\" class=\"dataframe\">\n",
|
219 |
-
" <thead>\n",
|
220 |
-
" <tr style=\"text-align: right;\">\n",
|
221 |
-
" <th></th>\n",
|
222 |
-
" <th>photoid</th>\n",
|
223 |
-
" <th>uid</th>\n",
|
224 |
-
" <th>unickname</th>\n",
|
225 |
-
" <th>datetaken</th>\n",
|
226 |
-
" <th>dateuploaded</th>\n",
|
227 |
-
" <th>capturedevice</th>\n",
|
228 |
-
" <th>title</th>\n",
|
229 |
-
" <th>description</th>\n",
|
230 |
-
" <th>usertags</th>\n",
|
231 |
-
" <th>machinetags</th>\n",
|
232 |
-
" <th>...</th>\n",
|
233 |
-
" <th>licenseurl</th>\n",
|
234 |
-
" <th>serverid</th>\n",
|
235 |
-
" <th>farmid</th>\n",
|
236 |
-
" <th>secret</th>\n",
|
237 |
-
" <th>secretoriginal</th>\n",
|
238 |
-
" <th>ext</th>\n",
|
239 |
-
" <th>marker</th>\n",
|
240 |
-
" <th>key</th>\n",
|
241 |
-
" <th>title_clean</th>\n",
|
242 |
-
" <th>description_clean</th>\n",
|
243 |
-
" </tr>\n",
|
244 |
-
" </thead>\n",
|
245 |
-
" <tbody>\n",
|
246 |
-
" <tr>\n",
|
247 |
-
" <th>0</th>\n",
|
248 |
-
" <td>137943</td>\n",
|
249 |
-
" <td>48600072071@N01</td>\n",
|
250 |
-
" <td>doctor+paradox</td>\n",
|
251 |
-
" <td>2004-08-01 18:13:06.0</td>\n",
|
252 |
-
" <td>1091409186</td>\n",
|
253 |
-
" <td>NaN</td>\n",
|
254 |
-
" <td>A+Picture+Share%21</td>\n",
|
255 |
-
" <td>Antenna</td>\n",
|
256 |
-
" <td>cameraphone,cayugaheights,green,hydrant,ithaca...</td>\n",
|
257 |
-
" <td>NaN</td>\n",
|
258 |
-
" <td>...</td>\n",
|
259 |
-
" <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
|
260 |
-
" <td>1</td>\n",
|
261 |
-
" <td>1</td>\n",
|
262 |
-
" <td>1650c7cdc6</td>\n",
|
263 |
-
" <td>1650c7cdc6</td>\n",
|
264 |
-
" <td>jpg</td>\n",
|
265 |
-
" <td>0</td>\n",
|
266 |
-
" <td>d29e7c6a3028418c64eb15e3cf577c2</td>\n",
|
267 |
-
" <td>A Picture Share!</td>\n",
|
268 |
-
" <td>Antenna</td>\n",
|
269 |
-
" </tr>\n",
|
270 |
-
" <tr>\n",
|
271 |
-
" <th>1</th>\n",
|
272 |
-
" <td>1246361</td>\n",
|
273 |
-
" <td>44124324682@N01</td>\n",
|
274 |
-
" <td>mharrsch</td>\n",
|
275 |
-
" <td>2004-11-03 23:04:02.0</td>\n",
|
276 |
-
" <td>1099523042</td>\n",
|
277 |
-
" <td>NaN</td>\n",
|
278 |
-
" <td>An+ornate+Roman+urn</td>\n",
|
279 |
-
" <td>Photographed+at+the+%3Ca+href%3D%22http%3A%2F%...</td>\n",
|
280 |
-
" <td>ancient,baltimore,burial,death,empire,funeral,...</td>\n",
|
281 |
-
" <td>NaN</td>\n",
|
282 |
-
" <td>...</td>\n",
|
283 |
-
" <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
|
284 |
-
" <td>1</td>\n",
|
285 |
-
" <td>1</td>\n",
|
286 |
-
" <td>cf37054610</td>\n",
|
287 |
-
" <td>cf37054610</td>\n",
|
288 |
-
" <td>jpg</td>\n",
|
289 |
-
" <td>0</td>\n",
|
290 |
-
" <td>d29f01b149167d683f9ddde464bb3db</td>\n",
|
291 |
-
" <td>An ornate Roman urn</td>\n",
|
292 |
-
" <td>Photographed at the Walters Art Museum, Baltim...</td>\n",
|
293 |
-
" </tr>\n",
|
294 |
-
" <tr>\n",
|
295 |
-
" <th>2</th>\n",
|
296 |
-
" <td>1251599</td>\n",
|
297 |
-
" <td>51035803024@N01</td>\n",
|
298 |
-
" <td>bmitd67</td>\n",
|
299 |
-
" <td>2004-10-30 17:09:32.0</td>\n",
|
300 |
-
" <td>1099538888</td>\n",
|
301 |
-
" <td>Canon+PowerShot+S30</td>\n",
|
302 |
-
" <td>Jai+%26+Tara+on+the+Cumberland</td>\n",
|
303 |
-
" <td>Another+trip+for+the+happy+couple.</td>\n",
|
304 |
-
" <td>blue+heron,cumberland+river,jai,tara,tennessee</td>\n",
|
305 |
-
" <td>NaN</td>\n",
|
306 |
-
" <td>...</td>\n",
|
307 |
-
" <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
|
308 |
-
" <td>1</td>\n",
|
309 |
-
" <td>1</td>\n",
|
310 |
-
" <td>4a4234e32c</td>\n",
|
311 |
-
" <td>4a4234e32c</td>\n",
|
312 |
-
" <td>jpg</td>\n",
|
313 |
-
" <td>0</td>\n",
|
314 |
-
" <td>d296e9e34bdae41edb6c679ff824ab2a</td>\n",
|
315 |
-
" <td>Jai & Tara on the Cumberland</td>\n",
|
316 |
-
" <td>Another trip for the happy couple.</td>\n",
|
317 |
-
" </tr>\n",
|
318 |
-
" <tr>\n",
|
319 |
-
" <th>3</th>\n",
|
320 |
-
" <td>2348587</td>\n",
|
321 |
-
" <td>73621375@N00</td>\n",
|
322 |
-
" <td>Thom+Watson</td>\n",
|
323 |
-
" <td>2004-12-18 21:08:09.0</td>\n",
|
324 |
-
" <td>1103497228</td>\n",
|
325 |
-
" <td>SONY+DSC-W1</td>\n",
|
326 |
-
" <td>Castle+gate+-+%22lite-brited%22</td>\n",
|
327 |
-
" <td>Taken+at+the+Miracle+of+Lights+display+in+Cent...</td>\n",
|
328 |
-
" <td>bullrunpark,castle,centreville,christmas,decor...</td>\n",
|
329 |
-
" <td>NaN</td>\n",
|
330 |
-
" <td>...</td>\n",
|
331 |
-
" <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
|
332 |
-
" <td>2</td>\n",
|
333 |
-
" <td>1</td>\n",
|
334 |
-
" <td>7162c974c3</td>\n",
|
335 |
-
" <td>7162c974c3</td>\n",
|
336 |
-
" <td>jpg</td>\n",
|
337 |
-
" <td>0</td>\n",
|
338 |
-
" <td>d29ce96395848478b1e8396e44899</td>\n",
|
339 |
-
" <td>Castle gate - \"lite-brited\"</td>\n",
|
340 |
-
" <td>Taken at the Miracle of Lights display in Cent...</td>\n",
|
341 |
-
" </tr>\n",
|
342 |
-
" <tr>\n",
|
343 |
-
" <th>4</th>\n",
|
344 |
-
" <td>3516047</td>\n",
|
345 |
-
" <td>48600072071@N01</td>\n",
|
346 |
-
" <td>doctor+paradox</td>\n",
|
347 |
-
" <td>2005-01-18 16:44:18.0</td>\n",
|
348 |
-
" <td>1106084658</td>\n",
|
349 |
-
" <td>NaN</td>\n",
|
350 |
-
" <td>A+Picture+Share%21</td>\n",
|
351 |
-
" <td>Tabular</td>\n",
|
352 |
-
" <td>cameraphone,moblog,unfound</td>\n",
|
353 |
-
" <td>NaN</td>\n",
|
354 |
-
" <td>...</td>\n",
|
355 |
-
" <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
|
356 |
-
" <td>3</td>\n",
|
357 |
-
" <td>1</td>\n",
|
358 |
-
" <td>663e0d8b3d</td>\n",
|
359 |
-
" <td>663e0d8b3d</td>\n",
|
360 |
-
" <td>jpg</td>\n",
|
361 |
-
" <td>0</td>\n",
|
362 |
-
" <td>d29abf32c4e12ff881f975b70e0cec0</td>\n",
|
363 |
-
" <td>A Picture Share!</td>\n",
|
364 |
-
" <td>Tabular</td>\n",
|
365 |
-
" </tr>\n",
|
366 |
-
" <tr>\n",
|
367 |
-
" <th>...</th>\n",
|
368 |
-
" <td>...</td>\n",
|
369 |
-
" <td>...</td>\n",
|
370 |
-
" <td>...</td>\n",
|
371 |
-
" <td>...</td>\n",
|
372 |
-
" <td>...</td>\n",
|
373 |
-
" <td>...</td>\n",
|
374 |
-
" <td>...</td>\n",
|
375 |
-
" <td>...</td>\n",
|
376 |
-
" <td>...</td>\n",
|
377 |
-
" <td>...</td>\n",
|
378 |
-
" <td>...</td>\n",
|
379 |
-
" <td>...</td>\n",
|
380 |
-
" <td>...</td>\n",
|
381 |
-
" <td>...</td>\n",
|
382 |
-
" <td>...</td>\n",
|
383 |
-
" <td>...</td>\n",
|
384 |
-
" <td>...</td>\n",
|
385 |
-
" <td>...</td>\n",
|
386 |
-
" <td>...</td>\n",
|
387 |
-
" <td>...</td>\n",
|
388 |
-
" <td>...</td>\n",
|
389 |
-
" </tr>\n",
|
390 |
-
" <tr>\n",
|
391 |
-
" <th>999995</th>\n",
|
392 |
-
" <td>4648651054</td>\n",
|
393 |
-
" <td>24511045@N04</td>\n",
|
394 |
-
" <td>mtfrazier</td>\n",
|
395 |
-
" <td>2010-05-02 15:47:45.0</td>\n",
|
396 |
-
" <td>1275083371</td>\n",
|
397 |
-
" <td>Canon+EOS+50D</td>\n",
|
398 |
-
" <td>U.S.+Navy+Blue+Angels%3A+2010</td>\n",
|
399 |
-
" <td>2+May+2010%0ASunday%0ASt.+Joseph%2C+Missouri</td>\n",
|
400 |
-
" <td>NaN</td>\n",
|
401 |
-
" <td>NaN</td>\n",
|
402 |
-
" <td>...</td>\n",
|
403 |
-
" <td>http://creativecommons.org/licenses/by-nc-nd/2.0/</td>\n",
|
404 |
-
" <td>4072</td>\n",
|
405 |
-
" <td>5</td>\n",
|
406 |
-
" <td>2d12d73fb0</td>\n",
|
407 |
-
" <td>dd5856ea42</td>\n",
|
408 |
-
" <td>jpg</td>\n",
|
409 |
-
" <td>0</td>\n",
|
410 |
-
" <td>60fa2911cb81eb25b356e9fee978aef</td>\n",
|
411 |
-
" <td>U.S. Navy Blue Angels: 2010</td>\n",
|
412 |
-
" <td>2 May 2010 Sunday St. Joseph, Missouri</td>\n",
|
413 |
-
" </tr>\n",
|
414 |
-
" <tr>\n",
|
415 |
-
" <th>999996</th>\n",
|
416 |
-
" <td>4652130996</td>\n",
|
417 |
-
" <td>21963865@N04</td>\n",
|
418 |
-
" <td>GRAB1.0</td>\n",
|
419 |
-
" <td>2010-05-29 19:23:10.0</td>\n",
|
420 |
-
" <td>1275200833</td>\n",
|
421 |
-
" <td>SONY+DSLR-A230</td>\n",
|
422 |
-
" <td>Attempts+on+Her+Life</td>\n",
|
423 |
-
" <td>BAPA+1+production+of+Martin+Crimp%27s+Attempts...</td>\n",
|
424 |
-
" <td>NaN</td>\n",
|
425 |
-
" <td>NaN</td>\n",
|
426 |
-
" <td>...</td>\n",
|
427 |
-
" <td>http://creativecommons.org/licenses/by-nc-nd/2.0/</td>\n",
|
428 |
-
" <td>4003</td>\n",
|
429 |
-
" <td>5</td>\n",
|
430 |
-
" <td>8889121579</td>\n",
|
431 |
-
" <td>2f46599456</td>\n",
|
432 |
-
" <td>jpg</td>\n",
|
433 |
-
" <td>0</td>\n",
|
434 |
-
" <td>60f5ef5ce4c2d24566226abebd67d4</td>\n",
|
435 |
-
" <td>Attempts on Her Life</td>\n",
|
436 |
-
" <td>BAPA 1 production of Martin Crimp's Attempts o...</td>\n",
|
437 |
-
" </tr>\n",
|
438 |
-
" <tr>\n",
|
439 |
-
" <th>999997</th>\n",
|
440 |
-
" <td>4652568339</td>\n",
|
441 |
-
" <td>64025277@N00</td>\n",
|
442 |
-
" <td>1Sock</td>\n",
|
443 |
-
" <td>2010-05-13 15:38:37.0</td>\n",
|
444 |
-
" <td>1275234267</td>\n",
|
445 |
-
" <td>Canon+EOS+DIGITAL+REBEL+XT</td>\n",
|
446 |
-
" <td>Carlsbad+Caverns+3</td>\n",
|
447 |
-
" <td>%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%...</td>\n",
|
448 |
-
" <td>carlsbad,carlsbad+caverns,cave,faa,new+mexico,...</td>\n",
|
449 |
-
" <td>NaN</td>\n",
|
450 |
-
" <td>...</td>\n",
|
451 |
-
" <td>http://creativecommons.org/licenses/by-nc-nd/2.0/</td>\n",
|
452 |
-
" <td>4010</td>\n",
|
453 |
-
" <td>5</td>\n",
|
454 |
-
" <td>0a1808a69e</td>\n",
|
455 |
-
" <td>cf6d348e3d</td>\n",
|
456 |
-
" <td>jpg</td>\n",
|
457 |
-
" <td>0</td>\n",
|
458 |
-
" <td>60f029482d1d1028fda5281daf498f</td>\n",
|
459 |
-
" <td>Carlsbad Caverns 3</td>\n",
|
460 |
-
" <td>♥♥♥♥♥♥♥ Interested in purchasing this photogra...</td>\n",
|
461 |
-
" </tr>\n",
|
462 |
-
" <tr>\n",
|
463 |
-
" <th>999998</th>\n",
|
464 |
-
" <td>4653110895</td>\n",
|
465 |
-
" <td>20483509@N00</td>\n",
|
466 |
-
" <td>subberculture</td>\n",
|
467 |
-
" <td>2010-05-30 15:37:05.0</td>\n",
|
468 |
-
" <td>1275245596</td>\n",
|
469 |
-
" <td>Canon+DIGITAL+IXUS+40</td>\n",
|
470 |
-
" <td>Want</td>\n",
|
471 |
-
" <td>Isn%27t+that+gorgeous%3F</td>\n",
|
472 |
-
" <td>2010,edinburgh+museum,may,phonebox,wood</td>\n",
|
473 |
-
" <td>NaN</td>\n",
|
474 |
-
" <td>...</td>\n",
|
475 |
-
" <td>http://creativecommons.org/licenses/by-sa/2.0/</td>\n",
|
476 |
-
" <td>4066</td>\n",
|
477 |
-
" <td>5</td>\n",
|
478 |
-
" <td>77c3b3a254</td>\n",
|
479 |
-
" <td>c4697e1511</td>\n",
|
480 |
-
" <td>jpg</td>\n",
|
481 |
-
" <td>0</td>\n",
|
482 |
-
" <td>60f72775f433cf8de3efaeb431866153</td>\n",
|
483 |
-
" <td>Want</td>\n",
|
484 |
-
" <td>Isn't that gorgeous?</td>\n",
|
485 |
-
" </tr>\n",
|
486 |
-
" <tr>\n",
|
487 |
-
" <th>999999</th>\n",
|
488 |
-
" <td>4655503987</td>\n",
|
489 |
-
" <td>8457193@N07</td>\n",
|
490 |
-
" <td>zackojones</td>\n",
|
491 |
-
" <td>2010-05-30 15:34:58.0</td>\n",
|
492 |
-
" <td>1275310230</td>\n",
|
493 |
-
" <td>Canon+EOS+7D</td>\n",
|
494 |
-
" <td>Summertime</td>\n",
|
495 |
-
" <td>You+gotta+love+it%21</td>\n",
|
496 |
-
" <td>georgia,savannah,united+states,us</td>\n",
|
497 |
-
" <td>NaN</td>\n",
|
498 |
-
" <td>...</td>\n",
|
499 |
-
" <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
|
500 |
-
" <td>4043</td>\n",
|
501 |
-
" <td>5</td>\n",
|
502 |
-
" <td>caff543bfe</td>\n",
|
503 |
-
" <td>f60952ac4d</td>\n",
|
504 |
-
" <td>jpg</td>\n",
|
505 |
-
" <td>0</td>\n",
|
506 |
-
" <td>60f687e11b913bce461e9525d8047e0</td>\n",
|
507 |
-
" <td>Summertime</td>\n",
|
508 |
-
" <td>You gotta love it!</td>\n",
|
509 |
-
" </tr>\n",
|
510 |
-
" </tbody>\n",
|
511 |
-
"</table>\n",
|
512 |
-
"<p>1000000 rows × 26 columns</p>\n",
|
513 |
-
"</div>"
|
514 |
-
],
|
515 |
-
"text/plain": [
|
516 |
-
" photoid uid unickname datetaken \\\n",
|
517 |
-
"0 137943 48600072071@N01 doctor+paradox 2004-08-01 18:13:06.0 \n",
|
518 |
-
"1 1246361 44124324682@N01 mharrsch 2004-11-03 23:04:02.0 \n",
|
519 |
-
"2 1251599 51035803024@N01 bmitd67 2004-10-30 17:09:32.0 \n",
|
520 |
-
"3 2348587 73621375@N00 Thom+Watson 2004-12-18 21:08:09.0 \n",
|
521 |
-
"4 3516047 48600072071@N01 doctor+paradox 2005-01-18 16:44:18.0 \n",
|
522 |
-
"... ... ... ... ... \n",
|
523 |
-
"999995 4648651054 24511045@N04 mtfrazier 2010-05-02 15:47:45.0 \n",
|
524 |
-
"999996 4652130996 21963865@N04 GRAB1.0 2010-05-29 19:23:10.0 \n",
|
525 |
-
"999997 4652568339 64025277@N00 1Sock 2010-05-13 15:38:37.0 \n",
|
526 |
-
"999998 4653110895 20483509@N00 subberculture 2010-05-30 15:37:05.0 \n",
|
527 |
-
"999999 4655503987 8457193@N07 zackojones 2010-05-30 15:34:58.0 \n",
|
528 |
-
"\n",
|
529 |
-
" dateuploaded capturedevice \\\n",
|
530 |
-
"0 1091409186 NaN \n",
|
531 |
-
"1 1099523042 NaN \n",
|
532 |
-
"2 1099538888 Canon+PowerShot+S30 \n",
|
533 |
-
"3 1103497228 SONY+DSC-W1 \n",
|
534 |
-
"4 1106084658 NaN \n",
|
535 |
-
"... ... ... \n",
|
536 |
-
"999995 1275083371 Canon+EOS+50D \n",
|
537 |
-
"999996 1275200833 SONY+DSLR-A230 \n",
|
538 |
-
"999997 1275234267 Canon+EOS+DIGITAL+REBEL+XT \n",
|
539 |
-
"999998 1275245596 Canon+DIGITAL+IXUS+40 \n",
|
540 |
-
"999999 1275310230 Canon+EOS+7D \n",
|
541 |
-
"\n",
|
542 |
-
" title \\\n",
|
543 |
-
"0 A+Picture+Share%21 \n",
|
544 |
-
"1 An+ornate+Roman+urn \n",
|
545 |
-
"2 Jai+%26+Tara+on+the+Cumberland \n",
|
546 |
-
"3 Castle+gate+-+%22lite-brited%22 \n",
|
547 |
-
"4 A+Picture+Share%21 \n",
|
548 |
-
"... ... \n",
|
549 |
-
"999995 U.S.+Navy+Blue+Angels%3A+2010 \n",
|
550 |
-
"999996 Attempts+on+Her+Life \n",
|
551 |
-
"999997 Carlsbad+Caverns+3 \n",
|
552 |
-
"999998 Want \n",
|
553 |
-
"999999 Summertime \n",
|
554 |
-
"\n",
|
555 |
-
" description \\\n",
|
556 |
-
"0 Antenna \n",
|
557 |
-
"1 Photographed+at+the+%3Ca+href%3D%22http%3A%2F%... \n",
|
558 |
-
"2 Another+trip+for+the+happy+couple. \n",
|
559 |
-
"3 Taken+at+the+Miracle+of+Lights+display+in+Cent... \n",
|
560 |
-
"4 Tabular \n",
|
561 |
-
"... ... \n",
|
562 |
-
"999995 2+May+2010%0ASunday%0ASt.+Joseph%2C+Missouri \n",
|
563 |
-
"999996 BAPA+1+production+of+Martin+Crimp%27s+Attempts... \n",
|
564 |
-
"999997 %E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%... \n",
|
565 |
-
"999998 Isn%27t+that+gorgeous%3F \n",
|
566 |
-
"999999 You+gotta+love+it%21 \n",
|
567 |
-
"\n",
|
568 |
-
" usertags machinetags ... \\\n",
|
569 |
-
"0 cameraphone,cayugaheights,green,hydrant,ithaca... NaN ... \n",
|
570 |
-
"1 ancient,baltimore,burial,death,empire,funeral,... NaN ... \n",
|
571 |
-
"2 blue+heron,cumberland+river,jai,tara,tennessee NaN ... \n",
|
572 |
-
"3 bullrunpark,castle,centreville,christmas,decor... NaN ... \n",
|
573 |
-
"4 cameraphone,moblog,unfound NaN ... \n",
|
574 |
-
"... ... ... ... \n",
|
575 |
-
"999995 NaN NaN ... \n",
|
576 |
-
"999996 NaN NaN ... \n",
|
577 |
-
"999997 carlsbad,carlsbad+caverns,cave,faa,new+mexico,... NaN ... \n",
|
578 |
-
"999998 2010,edinburgh+museum,may,phonebox,wood NaN ... \n",
|
579 |
-
"999999 georgia,savannah,united+states,us NaN ... \n",
|
580 |
-
"\n",
|
581 |
-
" licenseurl serverid farmid \\\n",
|
582 |
-
"0 http://creativecommons.org/licenses/by-nc-sa/2.0/ 1 1 \n",
|
583 |
-
"1 http://creativecommons.org/licenses/by-nc-sa/2.0/ 1 1 \n",
|
584 |
-
"2 http://creativecommons.org/licenses/by-nc-sa/2.0/ 1 1 \n",
|
585 |
-
"3 http://creativecommons.org/licenses/by-nc-sa/2.0/ 2 1 \n",
|
586 |
-
"4 http://creativecommons.org/licenses/by-nc-sa/2.0/ 3 1 \n",
|
587 |
-
"... ... ... ... \n",
|
588 |
-
"999995 http://creativecommons.org/licenses/by-nc-nd/2.0/ 4072 5 \n",
|
589 |
-
"999996 http://creativecommons.org/licenses/by-nc-nd/2.0/ 4003 5 \n",
|
590 |
-
"999997 http://creativecommons.org/licenses/by-nc-nd/2.0/ 4010 5 \n",
|
591 |
-
"999998 http://creativecommons.org/licenses/by-sa/2.0/ 4066 5 \n",
|
592 |
-
"999999 http://creativecommons.org/licenses/by-nc-sa/2.0/ 4043 5 \n",
|
593 |
-
"\n",
|
594 |
-
" secret secretoriginal ext marker \\\n",
|
595 |
-
"0 1650c7cdc6 1650c7cdc6 jpg 0 \n",
|
596 |
-
"1 cf37054610 cf37054610 jpg 0 \n",
|
597 |
-
"2 4a4234e32c 4a4234e32c jpg 0 \n",
|
598 |
-
"3 7162c974c3 7162c974c3 jpg 0 \n",
|
599 |
-
"4 663e0d8b3d 663e0d8b3d jpg 0 \n",
|
600 |
-
"... ... ... ... ... \n",
|
601 |
-
"999995 2d12d73fb0 dd5856ea42 jpg 0 \n",
|
602 |
-
"999996 8889121579 2f46599456 jpg 0 \n",
|
603 |
-
"999997 0a1808a69e cf6d348e3d jpg 0 \n",
|
604 |
-
"999998 77c3b3a254 c4697e1511 jpg 0 \n",
|
605 |
-
"999999 caff543bfe f60952ac4d jpg 0 \n",
|
606 |
-
"\n",
|
607 |
-
" key title_clean \\\n",
|
608 |
-
"0 d29e7c6a3028418c64eb15e3cf577c2 A Picture Share! \n",
|
609 |
-
"1 d29f01b149167d683f9ddde464bb3db An ornate Roman urn \n",
|
610 |
-
"2 d296e9e34bdae41edb6c679ff824ab2a Jai & Tara on the Cumberland \n",
|
611 |
-
"3 d29ce96395848478b1e8396e44899 Castle gate - \"lite-brited\" \n",
|
612 |
-
"4 d29abf32c4e12ff881f975b70e0cec0 A Picture Share! \n",
|
613 |
-
"... ... ... \n",
|
614 |
-
"999995 60fa2911cb81eb25b356e9fee978aef U.S. Navy Blue Angels: 2010 \n",
|
615 |
-
"999996 60f5ef5ce4c2d24566226abebd67d4 Attempts on Her Life \n",
|
616 |
-
"999997 60f029482d1d1028fda5281daf498f Carlsbad Caverns 3 \n",
|
617 |
-
"999998 60f72775f433cf8de3efaeb431866153 Want \n",
|
618 |
-
"999999 60f687e11b913bce461e9525d8047e0 Summertime \n",
|
619 |
-
"\n",
|
620 |
-
" description_clean \n",
|
621 |
-
"0 Antenna \n",
|
622 |
-
"1 Photographed at the Walters Art Museum, Baltim... \n",
|
623 |
-
"2 Another trip for the happy couple. \n",
|
624 |
-
"3 Taken at the Miracle of Lights display in Cent... \n",
|
625 |
-
"4 Tabular \n",
|
626 |
-
"... ... \n",
|
627 |
-
"999995 2 May 2010 Sunday St. Joseph, Missouri \n",
|
628 |
-
"999996 BAPA 1 production of Martin Crimp's Attempts o... \n",
|
629 |
-
"999997 ♥♥♥♥♥♥♥ Interested in purchasing this photogra... \n",
|
630 |
-
"999998 Isn't that gorgeous? \n",
|
631 |
-
"999999 You gotta love it! \n",
|
632 |
-
"\n",
|
633 |
-
"[1000000 rows x 26 columns]"
|
634 |
-
]
|
635 |
-
},
|
636 |
-
"execution_count": 25,
|
637 |
-
"metadata": {},
|
638 |
-
"output_type": "execute_result"
|
639 |
-
}
|
640 |
-
],
|
641 |
-
"source": [
|
642 |
-
"# looking up at a chunk\n",
|
643 |
-
"pd.read_csv(\"./chunks/chunk1.tsv\", sep=\"\\t\")"
|
644 |
-
]
|
645 |
-
},
|
646 |
-
{
|
647 |
-
"cell_type": "code",
|
648 |
-
"execution_count": 98,
|
649 |
-
"id": "c51c5597",
|
650 |
-
"metadata": {},
|
651 |
-
"outputs": [
|
652 |
-
{
|
653 |
-
"data": {
|
654 |
-
"text/html": [
|
655 |
-
"<div>\n",
|
656 |
-
"<style scoped>\n",
|
657 |
-
" .dataframe tbody tr th:only-of-type {\n",
|
658 |
-
" vertical-align: middle;\n",
|
659 |
-
" }\n",
|
660 |
-
"\n",
|
661 |
-
" .dataframe tbody tr th {\n",
|
662 |
-
" vertical-align: top;\n",
|
663 |
-
" }\n",
|
664 |
-
"\n",
|
665 |
-
" .dataframe thead th {\n",
|
666 |
-
" text-align: right;\n",
|
667 |
-
" }\n",
|
668 |
-
"</style>\n",
|
669 |
-
"<table border=\"1\" class=\"dataframe\">\n",
|
670 |
-
" <thead>\n",
|
671 |
-
" <tr style=\"text-align: right;\">\n",
|
672 |
-
" <th></th>\n",
|
673 |
-
" <th>key</th>\n",
|
674 |
-
" <th>title_clean</th>\n",
|
675 |
-
" <th>description_clean</th>\n",
|
676 |
-
" <th>ext</th>\n",
|
677 |
-
" </tr>\n",
|
678 |
-
" </thead>\n",
|
679 |
-
" <tbody>\n",
|
680 |
-
" <tr>\n",
|
681 |
-
" <th>0</th>\n",
|
682 |
-
" <td>d29e7c6a3028418c64eb15e3cf577c2</td>\n",
|
683 |
-
" <td>A Picture Share!</td>\n",
|
684 |
-
" <td>Antenna</td>\n",
|
685 |
-
" <td>jpg</td>\n",
|
686 |
-
" </tr>\n",
|
687 |
-
" <tr>\n",
|
688 |
-
" <th>1</th>\n",
|
689 |
-
" <td>d29f01b149167d683f9ddde464bb3db</td>\n",
|
690 |
-
" <td>An ornate Roman urn</td>\n",
|
691 |
-
" <td>Photographed at the Walters Art Museum, Baltim...</td>\n",
|
692 |
-
" <td>jpg</td>\n",
|
693 |
-
" </tr>\n",
|
694 |
-
" <tr>\n",
|
695 |
-
" <th>2</th>\n",
|
696 |
-
" <td>d296e9e34bdae41edb6c679ff824ab2a</td>\n",
|
697 |
-
" <td>Jai & Tara on the Cumberland</td>\n",
|
698 |
-
" <td>Another trip for the happy couple.</td>\n",
|
699 |
-
" <td>jpg</td>\n",
|
700 |
-
" </tr>\n",
|
701 |
-
" <tr>\n",
|
702 |
-
" <th>3</th>\n",
|
703 |
-
" <td>d29ce96395848478b1e8396e44899</td>\n",
|
704 |
-
" <td>Castle gate - \"lite-brited\"</td>\n",
|
705 |
-
" <td>Taken at the Miracle of Lights display in Cent...</td>\n",
|
706 |
-
" <td>jpg</td>\n",
|
707 |
-
" </tr>\n",
|
708 |
-
" <tr>\n",
|
709 |
-
" <th>4</th>\n",
|
710 |
-
" <td>d29abf32c4e12ff881f975b70e0cec0</td>\n",
|
711 |
-
" <td>A Picture Share!</td>\n",
|
712 |
-
" <td>Tabular</td>\n",
|
713 |
-
" <td>jpg</td>\n",
|
714 |
-
" </tr>\n",
|
715 |
-
" </tbody>\n",
|
716 |
-
"</table>\n",
|
717 |
-
"</div>"
|
718 |
-
],
|
719 |
-
"text/plain": [
|
720 |
-
" key title_clean \\\n",
|
721 |
-
"0 d29e7c6a3028418c64eb15e3cf577c2 A Picture Share! \n",
|
722 |
-
"1 d29f01b149167d683f9ddde464bb3db An ornate Roman urn \n",
|
723 |
-
"2 d296e9e34bdae41edb6c679ff824ab2a Jai & Tara on the Cumberland \n",
|
724 |
-
"3 d29ce96395848478b1e8396e44899 Castle gate - \"lite-brited\" \n",
|
725 |
-
"4 d29abf32c4e12ff881f975b70e0cec0 A Picture Share! \n",
|
726 |
-
"\n",
|
727 |
-
" description_clean ext \n",
|
728 |
-
"0 Antenna jpg \n",
|
729 |
-
"1 Photographed at the Walters Art Museum, Baltim... jpg \n",
|
730 |
-
"2 Another trip for the happy couple. jpg \n",
|
731 |
-
"3 Taken at the Miracle of Lights display in Cent... jpg \n",
|
732 |
-
"4 Tabular jpg "
|
733 |
-
]
|
734 |
-
},
|
735 |
-
"execution_count": 98,
|
736 |
-
"metadata": {},
|
737 |
-
"output_type": "execute_result"
|
738 |
-
}
|
739 |
-
],
|
740 |
-
"source": [
|
741 |
-
"# Looking at a chunk with only the relevant columns that we need\n",
|
742 |
-
"df = pd.read_csv(\"./chunks/chunk1.tsv\", sep=\"\\t\")[[\"key\", \"title_clean\", \"description_clean\", \"ext\"]]\n",
|
743 |
-
"df.head()"
|
744 |
-
]
|
745 |
-
},
|
746 |
-
{
|
747 |
-
"cell_type": "markdown",
|
748 |
-
"id": "cc1668f8",
|
749 |
-
"metadata": {},
|
750 |
-
"source": [
|
751 |
-
"### Grabbing each chunks from the folder, cleaning it up, only taking the entries which image exist and appending it to the global df"
|
752 |
-
]
|
753 |
-
},
|
754 |
-
{
|
755 |
-
"cell_type": "code",
|
756 |
-
"execution_count": null,
|
757 |
-
"id": "abbcccf3",
|
758 |
-
"metadata": {},
|
759 |
-
"outputs": [],
|
760 |
-
"source": [
|
761 |
-
"# the function that helps us to decide whether an image with certain id exists in storage, we only take the ones that we have the images for\n",
|
762 |
-
"def image_exists(item):\n",
|
763 |
-
" name, _, _, ext, _ = item\n",
|
764 |
-
" root=str(yfcc100m_images)\n",
|
765 |
-
" image_path = (Path(root)/name[0:3]/name[3:6]/name).with_suffix(\".\"+ext)\n",
|
766 |
-
" if image_path.exists():\n",
|
767 |
-
" return True\n",
|
768 |
-
" else:\n",
|
769 |
-
" return None"
|
770 |
-
]
|
771 |
-
},
|
772 |
-
{
|
773 |
-
"cell_type": "code",
|
774 |
-
"execution_count": 86,
|
775 |
-
"id": "44fa86ab",
|
776 |
-
"metadata": {},
|
777 |
-
"outputs": [],
|
778 |
-
"source": [
|
779 |
-
"# This cell does it all, grabs each chunk, cleans it up based on image existing condition, etc.\n",
|
780 |
-
"global_df = pd.DataFrame()\n",
|
781 |
-
"chunks_dir = \"./chunks\"\n",
|
782 |
-
"for filename in os.listdir(chunks_dir):\n",
|
783 |
-
" df = pd.read_csv(f\"./chunks/{str(filename)}\", sep=\"\\t\")[[\"key\", \"title_clean\", \"description_clean\", \"ext\"]]\n",
|
784 |
-
" df['caption'] = df[\"title_clean\"]+\". \"+df['description_clean']\n",
|
785 |
-
" df['is_exist'] = df.apply(image_exists, axis=1)\n",
|
786 |
-
" df = df.dropna()[[\"key\", \"caption\"]]\n",
|
787 |
-
" df.columns = ['image_file', 'caption']\n",
|
788 |
-
" global_df = global_df.append(df, ignore_index=True)"
|
789 |
-
]
|
790 |
-
},
|
791 |
-
{
|
792 |
-
"cell_type": "code",
|
793 |
-
"execution_count": 89,
|
794 |
-
"id": "45024fdc",
|
795 |
-
"metadata": {},
|
796 |
-
"outputs": [],
|
797 |
-
"source": [
|
798 |
-
"# saving the tsv to disk\n",
|
799 |
-
"global_df.to_csv('./chunks/YFCC_subset_clean.tsv', sep=\"\\t\", index=False)"
|
800 |
-
]
|
801 |
-
},
|
802 |
-
{
|
803 |
-
"cell_type": "code",
|
804 |
-
"execution_count": 101,
|
805 |
-
"id": "dca4eb73",
|
806 |
-
"metadata": {},
|
807 |
-
"outputs": [],
|
808 |
-
"source": [
|
809 |
-
"# loading the tsv from disk (for explicitness, also my electricity was gone, glad it happened after I saved to the disk :( )\n",
|
810 |
-
"\n",
|
811 |
-
"dataset = pd.read_csv(f\"./chunks/YFCC_subset_clean.tsv\", sep=\"\\t\")"
|
812 |
-
]
|
813 |
-
},
|
814 |
-
{
|
815 |
-
"cell_type": "code",
|
816 |
-
"execution_count": 153,
|
817 |
-
"id": "a511264a",
|
818 |
-
"metadata": {},
|
819 |
-
"outputs": [],
|
820 |
-
"source": [
|
821 |
-
"\"\"\"\n",
|
822 |
-
"Luke Melas-Kyriazi's dataset.py's modified version for YFCC\n",
|
823 |
-
"\"\"\"\n",
|
824 |
-
"import warnings\n",
|
825 |
-
"from typing import Optional, Callable\n",
|
826 |
-
"from pathlib import Path\n",
|
827 |
-
"import numpy as np\n",
|
828 |
-
"import torch\n",
|
829 |
-
"import pandas as pd\n",
|
830 |
-
"from torch.utils.data import Dataset\n",
|
831 |
-
"from torchvision.datasets.folder import default_loader\n",
|
832 |
-
"from PIL import ImageFile\n",
|
833 |
-
"from PIL.Image import DecompressionBombWarning\n",
|
834 |
-
"ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
|
835 |
-
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
836 |
-
"warnings.filterwarnings(\"ignore\", category=DecompressionBombWarning)\n",
|
837 |
-
"\n",
|
838 |
-
"\n",
|
839 |
-
"class CaptionDataset(Dataset):\n",
|
840 |
-
" \"\"\"\n",
|
841 |
-
" A PyTorch Dataset class for (image, texts) tasks. Note that this dataset \n",
|
842 |
-
" returns the raw text rather than tokens. This is done on purpose, because\n",
|
843 |
-
" it's easy to tokenize a batch of text after loading it from this dataset.\n",
|
844 |
-
" \"\"\"\n",
|
845 |
-
"\n",
|
846 |
-
" def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None, \n",
|
847 |
-
" image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',\n",
|
848 |
-
" include_captions: bool = True):\n",
|
849 |
-
" \"\"\"\n",
|
850 |
-
" :param images_root: folder where images are stored\n",
|
851 |
-
" :param captions_path: path to csv that maps image filenames to captions\n",
|
852 |
-
" :param image_transform: image transform pipeline\n",
|
853 |
-
" :param text_transform: image transform pipeline\n",
|
854 |
-
" :param image_transform_type: image transform type, either `torchvision` or `albumentations`\n",
|
855 |
-
" :param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.\n",
|
856 |
-
" \"\"\"\n",
|
857 |
-
"\n",
|
858 |
-
" # Base path for images\n",
|
859 |
-
" self.images_root = Path(images_root)\n",
|
860 |
-
"\n",
|
861 |
-
" # Load captions as DataFrame\n",
|
862 |
-
" self.captions = pd.read_csv(f\"./chunks/YFCC_subset_clean.tsv\", sep=\"\\t\")\n",
|
863 |
-
" self.captions['image_file'] = self.captions['image_file'].astype(str)\n",
|
864 |
-
"\n",
|
865 |
-
" # PyTorch transformation pipeline for the image (normalizing, etc.)\n",
|
866 |
-
" self.text_transform = text_transform\n",
|
867 |
-
" self.image_transform = image_transform\n",
|
868 |
-
" self.image_transform_type = image_transform_type.lower()\n",
|
869 |
-
" assert self.image_transform_type in ['torchvision', 'albumentations']\n",
|
870 |
-
"\n",
|
871 |
-
" # Total number of datapoints\n",
|
872 |
-
" self.size = len(self.captions)\n",
|
873 |
-
"\n",
|
874 |
-
" # Return image+captions or just images\n",
|
875 |
-
" self.include_captions = include_captions\n",
|
876 |
-
" \n",
|
877 |
-
" def image_exists(item):\n",
|
878 |
-
" name, caption = item\n",
|
879 |
-
" root=str(self.images_root)\n",
|
880 |
-
" image_path = (Path(root)/name[0:3]/name[3:6]/name).with_suffix(\".jpg\")\n",
|
881 |
-
"\n",
|
882 |
-
" return image_path.exists()\n",
|
883 |
-
"\n",
|
884 |
-
" def verify_that_all_images_exist(self):\n",
|
885 |
-
" for image_file in self.captions['image_file']:\n",
|
886 |
-
" if not image_exists:\n",
|
887 |
-
" print(f'file does not exist: {p}')\n",
|
888 |
-
"\n",
|
889 |
-
" def _get_raw_image(self, i):\n",
|
890 |
-
" name = self.captions.iloc[i]['image_file']\n",
|
891 |
-
" image_path = (Path(self.images_root)/name[0:3]/name[3:6]/name).with_suffix(\".jpg\")\n",
|
892 |
-
" image = default_loader(image_path)\n",
|
893 |
-
" return image\n",
|
894 |
-
"\n",
|
895 |
-
" def _get_raw_text(self, i):\n",
|
896 |
-
" return self.captions.iloc[i]['caption']\n",
|
897 |
-
"\n",
|
898 |
-
" def __getitem__(self, i):\n",
|
899 |
-
" image = self._get_raw_image(i)\n",
|
900 |
-
" caption = self._get_raw_text(i)\n",
|
901 |
-
" if self.image_transform is not None:\n",
|
902 |
-
" if self.image_transform_type == 'torchvision':\n",
|
903 |
-
" image = self.image_transform(image)\n",
|
904 |
-
" elif self.image_transform_type == 'albumentations':\n",
|
905 |
-
" image = self.image_transform(image=np.array(image))['image']\n",
|
906 |
-
" else:\n",
|
907 |
-
" raise NotImplementedError(f\"{self.image_transform_type=}\")\n",
|
908 |
-
" return {'image': image, 'text': caption} if self.include_captions else image\n",
|
909 |
-
"\n",
|
910 |
-
" def __len__(self):\n",
|
911 |
-
" return self.size\n",
|
912 |
-
"\n",
|
913 |
-
"\n",
|
914 |
-
"if __name__ == \"__main__\":\n",
|
915 |
-
" import albumentations as A\n",
|
916 |
-
" from albumentations.pytorch import ToTensorV2\n",
|
917 |
-
" from transformers import AutoTokenizer\n",
|
918 |
-
" \n",
|
919 |
-
"\n",
|
920 |
-
" images_root = \"/home/khali/TPU-Test/YFCC100M_OpenAI_subset/data/data/images\"\n",
|
921 |
-
" captions_path = './YFCC_subset_clean.tsv'\n",
|
922 |
-
" image_size = 256\n",
|
923 |
-
" \n",
|
924 |
-
" # Create transforms\n",
|
925 |
-
" def image_transform(image):\n",
|
926 |
-
" s = min(image.size)\n",
|
927 |
-
" r = image_size / s\n",
|
928 |
-
" s = (round(r * image.size[1]), round(r * image.size[0]))\n",
|
929 |
-
" image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n",
|
930 |
-
" image = TF.center_crop(image, output_size = 2 * [image_size])\n",
|
931 |
-
" image = torch.unsqueeze(T.ToTensor()(image), 0)\n",
|
932 |
-
" image = image.permute(0, 2, 3, 1).numpy()\n",
|
933 |
-
" return image\n",
|
934 |
-
" \n",
|
935 |
-
" # Create dataset\n",
|
936 |
-
" dataset = CaptionDataset(\n",
|
937 |
-
" images_root=images_root,\n",
|
938 |
-
" captions_path=captions_path,\n",
|
939 |
-
" image_transform=image_transform,\n",
|
940 |
-
" image_transform_type='torchvision',\n",
|
941 |
-
" include_captions=False\n",
|
942 |
-
" )"
|
943 |
-
]
|
944 |
-
},
|
945 |
-
{
|
946 |
-
"cell_type": "code",
|
947 |
-
"execution_count": 155,
|
948 |
-
"id": "cc922704",
|
949 |
-
"metadata": {},
|
950 |
-
"outputs": [
|
951 |
-
{
|
952 |
-
"data": {
|
953 |
-
"text/plain": [
|
954 |
-
"2483316"
|
955 |
-
]
|
956 |
-
},
|
957 |
-
"execution_count": 155,
|
958 |
-
"metadata": {},
|
959 |
-
"output_type": "execute_result"
|
960 |
-
}
|
961 |
-
],
|
962 |
-
"source": [
|
963 |
-
"len(dataset)"
|
964 |
-
]
|
965 |
-
},
|
966 |
-
{
|
967 |
-
"cell_type": "code",
|
968 |
-
"execution_count": 156,
|
969 |
-
"id": "6e47ba46",
|
970 |
-
"metadata": {},
|
971 |
-
"outputs": [],
|
972 |
-
"source": [
|
973 |
-
"dataloader = DataLoader(dataset, batch_size=32, num_workers=4)"
|
974 |
-
]
|
975 |
-
},
|
976 |
-
{
|
977 |
-
"cell_type": "code",
|
978 |
-
"execution_count": 1,
|
979 |
-
"id": "c8a130eb",
|
980 |
-
"metadata": {},
|
981 |
-
"outputs": [],
|
982 |
-
"source": [
|
983 |
-
"# looking at a batch\n",
|
984 |
-
"next(iter(dataloader))"
|
985 |
-
]
|
986 |
-
},
|
987 |
-
{
|
988 |
-
"cell_type": "code",
|
989 |
-
"execution_count": null,
|
990 |
-
"id": "c192fd44",
|
991 |
-
"metadata": {},
|
992 |
-
"outputs": [],
|
993 |
-
"source": [
|
994 |
-
"# import matplotlib.pyplot as plt\n",
|
995 |
-
"# for tensor_image, _ in dataloader:\n",
|
996 |
-
"# print(tensor_image)\n",
|
997 |
-
"# plt.imshow(tensor_image.permute(1, 2, 0))\n",
|
998 |
-
"# break"
|
999 |
-
]
|
1000 |
-
},
|
1001 |
-
{
|
1002 |
-
"cell_type": "markdown",
|
1003 |
-
"id": "62ad01c3",
|
1004 |
-
"metadata": {},
|
1005 |
-
"source": [
|
1006 |
-
"## Encoding"
|
1007 |
-
]
|
1008 |
-
},
|
1009 |
-
{
|
1010 |
-
"cell_type": "code",
|
1011 |
-
"execution_count": 158,
|
1012 |
-
"id": "88f36d0b",
|
1013 |
-
"metadata": {},
|
1014 |
-
"outputs": [],
|
1015 |
-
"source": [
|
1016 |
-
"def encode(model, batch):\n",
|
1017 |
-
"# print(\"jitting encode function\")\n",
|
1018 |
-
" _, indices = model.encode(batch)\n",
|
1019 |
-
" return indices"
|
1020 |
-
]
|
1021 |
-
},
|
1022 |
-
{
|
1023 |
-
"cell_type": "code",
|
1024 |
-
"execution_count": 160,
|
1025 |
-
"id": "1f35f0cb",
|
1026 |
-
"metadata": {},
|
1027 |
-
"outputs": [],
|
1028 |
-
"source": [
|
1029 |
-
"def superbatch_generator(dataloader, num_tpus):\n",
|
1030 |
-
" iter_loader = iter(dataloader)\n",
|
1031 |
-
" for batch in iter_loader:\n",
|
1032 |
-
" superbatch = [batch.squeeze(1)]\n",
|
1033 |
-
" try:\n",
|
1034 |
-
" for b in range(num_tpus-1):\n",
|
1035 |
-
" batch = next(iter_loader)\n",
|
1036 |
-
" if batch is None:\n",
|
1037 |
-
" break\n",
|
1038 |
-
" # Skip incomplete last batch\n",
|
1039 |
-
" if batch.shape[0] == dataloader.batch_size:\n",
|
1040 |
-
" superbatch.append(batch.squeeze(1))\n",
|
1041 |
-
" except StopIteration:\n",
|
1042 |
-
" pass\n",
|
1043 |
-
" superbatch = torch.stack(superbatch, axis=0)\n",
|
1044 |
-
" yield superbatch"
|
1045 |
-
]
|
1046 |
-
},
|
1047 |
-
{
|
1048 |
-
"cell_type": "code",
|
1049 |
-
"execution_count": 170,
|
1050 |
-
"id": "2210705b",
|
1051 |
-
"metadata": {},
|
1052 |
-
"outputs": [],
|
1053 |
-
"source": [
|
1054 |
-
"import os\n",
|
1055 |
-
"\n",
|
1056 |
-
"def encode_captioned_dataset(dataset, output_tsv, batch_size=32, num_workers=16):\n",
|
1057 |
-
" if os.path.isfile(output_tsv):\n",
|
1058 |
-
" print(f\"Destination file {output_tsv} already exists, please move away.\")\n",
|
1059 |
-
" return\n",
|
1060 |
-
" \n",
|
1061 |
-
" num_tpus = 8 \n",
|
1062 |
-
" dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)\n",
|
1063 |
-
" superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)\n",
|
1064 |
-
" \n",
|
1065 |
-
" p_encoder = pmap(lambda batch: encode(model, batch))\n",
|
1066 |
-
"\n",
|
1067 |
-
" # We save each superbatch to avoid reallocation of buffers as we process them.\n",
|
1068 |
-
" # We keep the file open to prevent excessive file seeks.\n",
|
1069 |
-
" with open(output_tsv, \"w\") as file:\n",
|
1070 |
-
" iterations = len(dataset) // (batch_size * num_tpus)\n",
|
1071 |
-
" for n in tqdm(range(iterations)):\n",
|
1072 |
-
" superbatch = next(superbatches)\n",
|
1073 |
-
" encoded = p_encoder(superbatch.numpy())\n",
|
1074 |
-
" encoded = encoded.reshape(-1, encoded.shape[-1])\n",
|
1075 |
-
"\n",
|
1076 |
-
" # Extract fields from the dataset internal `captions` property, and save to disk\n",
|
1077 |
-
" start_index = n * batch_size * num_tpus\n",
|
1078 |
-
" end_index = (n+1) * batch_size * num_tpus\n",
|
1079 |
-
" paths = dataset.captions[\"image_file\"][start_index:end_index].values\n",
|
1080 |
-
" captions = dataset.captions[\"caption\"][start_index:end_index].values\n",
|
1081 |
-
" encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
|
1082 |
-
" batch_df = pd.DataFrame.from_dict({\"image_file\": paths, \"caption\": captions, \"encoding\": encoded_as_string})\n",
|
1083 |
-
" batch_df.to_csv(file, sep='\\t', header=(n==0), index=None)"
|
1084 |
-
]
|
1085 |
-
},
|
1086 |
-
{
|
1087 |
-
"cell_type": "code",
|
1088 |
-
"execution_count": 171,
|
1089 |
-
"id": "7704863d",
|
1090 |
-
"metadata": {},
|
1091 |
-
"outputs": [
|
1092 |
-
{
|
1093 |
-
"name": "stderr",
|
1094 |
-
"output_type": "stream",
|
1095 |
-
"text": [
|
1096 |
-
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4850/4850 [2:27:51<00:00, 1.83s/it]\n"
|
1097 |
-
]
|
1098 |
-
}
|
1099 |
-
],
|
1100 |
-
"source": [
|
1101 |
-
"encode_captioned_dataset(dataset, yfcc100m_output, batch_size=64, num_workers=16)"
|
1102 |
-
]
|
1103 |
-
},
|
1104 |
-
{
|
1105 |
-
"cell_type": "markdown",
|
1106 |
-
"id": "8953dd84",
|
1107 |
-
"metadata": {},
|
1108 |
-
"source": [
|
1109 |
-
"----"
|
1110 |
-
]
|
1111 |
-
}
|
1112 |
-
],
|
1113 |
-
"metadata": {
|
1114 |
-
"kernelspec": {
|
1115 |
-
"name": "python3",
|
1116 |
-
"display_name": "Python 3.9.0 64-bit ('Python39')"
|
1117 |
-
},
|
1118 |
-
"language_info": {
|
1119 |
-
"codemirror_mode": {
|
1120 |
-
"name": "ipython",
|
1121 |
-
"version": 3
|
1122 |
-
},
|
1123 |
-
"file_extension": ".py",
|
1124 |
-
"mimetype": "text/x-python",
|
1125 |
-
"name": "python",
|
1126 |
-
"nbconvert_exporter": "python",
|
1127 |
-
"pygments_lexer": "ipython3",
|
1128 |
-
"version": "3.9.0"
|
1129 |
-
},
|
1130 |
-
"interpreter": {
|
1131 |
-
"hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
|
1132 |
-
}
|
1133 |
-
},
|
1134 |
-
"nbformat": 4,
|
1135 |
-
"nbformat_minor": 5
|
1136 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
encoding/vqgan-jax-encoding.ipynb
DELETED
The diff for this file is too large to render.
See raw diff
|
|
environment.yaml
DELETED
@@ -1,10 +0,0 @@
|
|
1 |
-
name: dalle
|
2 |
-
channels:
|
3 |
-
- defaults
|
4 |
-
dependencies:
|
5 |
-
- python=3.9.5
|
6 |
-
- pip=21.1.3
|
7 |
-
- ipython=7.22.0
|
8 |
-
- cudatoolkit
|
9 |
-
- pip:
|
10 |
-
- -r requirements.txt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
img/logo.png
ADDED
model/data-pipeline.ipynb
DELETED
@@ -1,385 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "markdown",
|
5 |
-
"id": "bf8fb38a",
|
6 |
-
"metadata": {},
|
7 |
-
"source": [
|
8 |
-
"# Data Pipeline"
|
9 |
-
]
|
10 |
-
},
|
11 |
-
{
|
12 |
-
"cell_type": "code",
|
13 |
-
"execution_count": 1,
|
14 |
-
"id": "9b83dcb9",
|
15 |
-
"metadata": {},
|
16 |
-
"outputs": [],
|
17 |
-
"source": [
|
18 |
-
"from dataclasses import dataclass, field\n",
|
19 |
-
"from pathlib import Path\n",
|
20 |
-
"\n",
|
21 |
-
"import datasets\n",
|
22 |
-
"from datasets import Dataset, load_dataset\n",
|
23 |
-
"import numpy as np\n",
|
24 |
-
"\n",
|
25 |
-
"from transformers import BartTokenizer\n",
|
26 |
-
"\n",
|
27 |
-
"from tqdm import tqdm\n",
|
28 |
-
"\n",
|
29 |
-
"import jax\n",
|
30 |
-
"import jax.numpy as jnp\n",
|
31 |
-
"\n",
|
32 |
-
"from flax.training.common_utils import shard"
|
33 |
-
]
|
34 |
-
},
|
35 |
-
{
|
36 |
-
"cell_type": "markdown",
|
37 |
-
"id": "a661a89e",
|
38 |
-
"metadata": {},
|
39 |
-
"source": [
|
40 |
-
"File containing image paths, captions and VQGAN-encoded indices."
|
41 |
-
]
|
42 |
-
},
|
43 |
-
{
|
44 |
-
"cell_type": "code",
|
45 |
-
"execution_count": 2,
|
46 |
-
"id": "0e84e889",
|
47 |
-
"metadata": {},
|
48 |
-
"outputs": [],
|
49 |
-
"source": [
|
50 |
-
"datafile = '/data/CC12M/images-encoded-10000.tsv' # 9999 encoded images from CC12M"
|
51 |
-
]
|
52 |
-
},
|
53 |
-
{
|
54 |
-
"cell_type": "markdown",
|
55 |
-
"id": "7fdc640b",
|
56 |
-
"metadata": {},
|
57 |
-
"source": [
|
58 |
-
"TODO: generate train/test splits if necessary."
|
59 |
-
]
|
60 |
-
},
|
61 |
-
{
|
62 |
-
"cell_type": "code",
|
63 |
-
"execution_count": 3,
|
64 |
-
"id": "cc6789b4",
|
65 |
-
"metadata": {},
|
66 |
-
"outputs": [
|
67 |
-
{
|
68 |
-
"name": "stderr",
|
69 |
-
"output_type": "stream",
|
70 |
-
"text": [
|
71 |
-
"Using custom data configuration default-91833df78e844785\n",
|
72 |
-
"Reusing dataset csv (/home/pedro/.cache/huggingface/datasets/csv/default-91833df78e844785/0.0.0/e138af468cb14e747fb46a19c787ffcfa5170c821476d20d5304287ce12bbc23)\n"
|
73 |
-
]
|
74 |
-
}
|
75 |
-
],
|
76 |
-
"source": [
|
77 |
-
"dataset = load_dataset('csv', delimiter='\\t', data_files=[datafile])"
|
78 |
-
]
|
79 |
-
},
|
80 |
-
{
|
81 |
-
"cell_type": "code",
|
82 |
-
"execution_count": 4,
|
83 |
-
"id": "f3ed4919",
|
84 |
-
"metadata": {},
|
85 |
-
"outputs": [
|
86 |
-
{
|
87 |
-
"data": {
|
88 |
-
"text/plain": [
|
89 |
-
"DatasetDict({\n",
|
90 |
-
" train: Dataset({\n",
|
91 |
-
" features: ['image_file', 'caption', 'encoding'],\n",
|
92 |
-
" num_rows: 9999\n",
|
93 |
-
" })\n",
|
94 |
-
"})"
|
95 |
-
]
|
96 |
-
},
|
97 |
-
"execution_count": 4,
|
98 |
-
"metadata": {},
|
99 |
-
"output_type": "execute_result"
|
100 |
-
}
|
101 |
-
],
|
102 |
-
"source": [
|
103 |
-
"dataset"
|
104 |
-
]
|
105 |
-
},
|
106 |
-
{
|
107 |
-
"cell_type": "code",
|
108 |
-
"execution_count": 5,
|
109 |
-
"id": "a70c7354",
|
110 |
-
"metadata": {},
|
111 |
-
"outputs": [
|
112 |
-
{
|
113 |
-
"data": {
|
114 |
-
"text/plain": [
|
115 |
-
"Dataset({\n",
|
116 |
-
" features: ['image_file', 'caption', 'encoding'],\n",
|
117 |
-
" num_rows: 9999\n",
|
118 |
-
"})"
|
119 |
-
]
|
120 |
-
},
|
121 |
-
"execution_count": 5,
|
122 |
-
"metadata": {},
|
123 |
-
"output_type": "execute_result"
|
124 |
-
}
|
125 |
-
],
|
126 |
-
"source": [
|
127 |
-
"dataset = dataset[\"train\"]\n",
|
128 |
-
"dataset"
|
129 |
-
]
|
130 |
-
},
|
131 |
-
{
|
132 |
-
"cell_type": "markdown",
|
133 |
-
"id": "a73454cf",
|
134 |
-
"metadata": {},
|
135 |
-
"source": [
|
136 |
-
"We don't really need the `image_file` field for training. We'll drop it during pre-processing because we won't be able to numericalize it to a `jnp.array`, which would be required in JAX."
|
137 |
-
]
|
138 |
-
},
|
139 |
-
{
|
140 |
-
"cell_type": "markdown",
|
141 |
-
"id": "7c0fa992",
|
142 |
-
"metadata": {},
|
143 |
-
"source": [
|
144 |
-
"## Preprocessing"
|
145 |
-
]
|
146 |
-
},
|
147 |
-
{
|
148 |
-
"cell_type": "markdown",
|
149 |
-
"id": "a0e36582",
|
150 |
-
"metadata": {},
|
151 |
-
"source": [
|
152 |
-
"The `encoding` field contains a string representation of the encoded indices. We'll convert them to numbers. We also need to tokenize the captions."
|
153 |
-
]
|
154 |
-
},
|
155 |
-
{
|
156 |
-
"cell_type": "code",
|
157 |
-
"execution_count": 6,
|
158 |
-
"id": "d46f6ac5",
|
159 |
-
"metadata": {},
|
160 |
-
"outputs": [],
|
161 |
-
"source": [
|
162 |
-
"# Setting padding=\"max_length\" as we need fixed length inputs for jitted functions\n",
|
163 |
-
"max_length = 256 # Read from data_args.max_source_length\n",
|
164 |
-
"tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')\n",
|
165 |
-
"image_bos = 16384 # Max token is 16383 in our VQGAN configuration"
|
166 |
-
]
|
167 |
-
},
|
168 |
-
{
|
169 |
-
"cell_type": "code",
|
170 |
-
"execution_count": 7,
|
171 |
-
"id": "4cac6643",
|
172 |
-
"metadata": {},
|
173 |
-
"outputs": [],
|
174 |
-
"source": [
|
175 |
-
"def preprocess_function(examples):\n",
|
176 |
-
" inputs = examples[\"caption\"]\n",
|
177 |
-
"# inputs = [prefix + inp for inp in inputs] # Do we need this?\n",
|
178 |
-
" model_inputs = tokenizer(\n",
|
179 |
-
" inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"np\"\n",
|
180 |
-
" )\n",
|
181 |
-
"\n",
|
182 |
-
" model_inputs[\"labels\"] = [[image_bos] + eval(indices) for indices in examples['encoding']]\n",
|
183 |
-
"\n",
|
184 |
-
" return model_inputs"
|
185 |
-
]
|
186 |
-
},
|
187 |
-
{
|
188 |
-
"cell_type": "code",
|
189 |
-
"execution_count": 8,
|
190 |
-
"id": "e6a4cb91",
|
191 |
-
"metadata": {},
|
192 |
-
"outputs": [],
|
193 |
-
"source": [
|
194 |
-
"num_workers = 48 # We have 96 processors in the TPU\n",
|
195 |
-
"column_names = dataset.column_names\n",
|
196 |
-
"input_dataset = dataset.map(preprocess_function,\n",
|
197 |
-
" remove_columns=column_names,\n",
|
198 |
-
" batched=True,\n",
|
199 |
-
" num_proc=48\n",
|
200 |
-
")"
|
201 |
-
]
|
202 |
-
},
|
203 |
-
{
|
204 |
-
"cell_type": "code",
|
205 |
-
"execution_count": 9,
|
206 |
-
"id": "a9b1b467",
|
207 |
-
"metadata": {},
|
208 |
-
"outputs": [],
|
209 |
-
"source": [
|
210 |
-
"def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):\n",
|
211 |
-
" \"\"\"\n",
|
212 |
-
" Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.\n",
|
213 |
-
" Shuffle batches if `shuffle` is `True`.\n",
|
214 |
-
" \"\"\"\n",
|
215 |
-
" steps_per_epoch = len(dataset) // batch_size\n",
|
216 |
-
"\n",
|
217 |
-
" if shuffle:\n",
|
218 |
-
" batch_idx = jax.random.permutation(rng, len(dataset))\n",
|
219 |
-
" else:\n",
|
220 |
-
" batch_idx = jnp.arange(len(dataset))\n",
|
221 |
-
"\n",
|
222 |
-
" batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.\n",
|
223 |
-
" batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))\n",
|
224 |
-
"\n",
|
225 |
-
" for idx in batch_idx:\n",
|
226 |
-
" batch = dataset[idx] \n",
|
227 |
-
" batch = {k: jnp.array(v) for k, v in batch.items()}\n",
|
228 |
-
" batch = shard(batch)\n",
|
229 |
-
" yield batch"
|
230 |
-
]
|
231 |
-
},
|
232 |
-
{
|
233 |
-
"cell_type": "code",
|
234 |
-
"execution_count": 10,
|
235 |
-
"id": "0a628505",
|
236 |
-
"metadata": {},
|
237 |
-
"outputs": [
|
238 |
-
{
|
239 |
-
"name": "stderr",
|
240 |
-
"output_type": "stream",
|
241 |
-
"text": [
|
242 |
-
"INFO:absl:Starting the local TPU driver.\n",
|
243 |
-
"INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
|
244 |
-
"INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Host TPU Interpreter\n"
|
245 |
-
]
|
246 |
-
}
|
247 |
-
],
|
248 |
-
"source": [
|
249 |
-
"rng = jax.random.PRNGKey(23) # Use training_args.seed\n",
|
250 |
-
"batch_size = 64 # Per device\n",
|
251 |
-
"super_batch_size = batch_size * jax.device_count()"
|
252 |
-
]
|
253 |
-
},
|
254 |
-
{
|
255 |
-
"cell_type": "code",
|
256 |
-
"execution_count": 11,
|
257 |
-
"id": "b3a5ce7d",
|
258 |
-
"metadata": {},
|
259 |
-
"outputs": [],
|
260 |
-
"source": [
|
261 |
-
"loader = data_loader(rng, input_dataset, batch_size=super_batch_size)"
|
262 |
-
]
|
263 |
-
},
|
264 |
-
{
|
265 |
-
"cell_type": "code",
|
266 |
-
"execution_count": 12,
|
267 |
-
"id": "67aa8f9c",
|
268 |
-
"metadata": {},
|
269 |
-
"outputs": [],
|
270 |
-
"source": [
|
271 |
-
"superbatch = next(iter(loader))"
|
272 |
-
]
|
273 |
-
},
|
274 |
-
{
|
275 |
-
"cell_type": "code",
|
276 |
-
"execution_count": 13,
|
277 |
-
"id": "7cd99402",
|
278 |
-
"metadata": {},
|
279 |
-
"outputs": [
|
280 |
-
{
|
281 |
-
"data": {
|
282 |
-
"text/plain": [
|
283 |
-
"dict_keys(['attention_mask', 'input_ids', 'labels'])"
|
284 |
-
]
|
285 |
-
},
|
286 |
-
"execution_count": 13,
|
287 |
-
"metadata": {},
|
288 |
-
"output_type": "execute_result"
|
289 |
-
}
|
290 |
-
],
|
291 |
-
"source": [
|
292 |
-
"superbatch.keys()"
|
293 |
-
]
|
294 |
-
},
|
295 |
-
{
|
296 |
-
"cell_type": "code",
|
297 |
-
"execution_count": 14,
|
298 |
-
"id": "652a4a9e",
|
299 |
-
"metadata": {},
|
300 |
-
"outputs": [
|
301 |
-
{
|
302 |
-
"data": {
|
303 |
-
"text/plain": [
|
304 |
-
"8"
|
305 |
-
]
|
306 |
-
},
|
307 |
-
"execution_count": 14,
|
308 |
-
"metadata": {},
|
309 |
-
"output_type": "execute_result"
|
310 |
-
}
|
311 |
-
],
|
312 |
-
"source": [
|
313 |
-
"len(superbatch[\"labels\"])"
|
314 |
-
]
|
315 |
-
},
|
316 |
-
{
|
317 |
-
"cell_type": "code",
|
318 |
-
"execution_count": 15,
|
319 |
-
"id": "de7de4e8",
|
320 |
-
"metadata": {},
|
321 |
-
"outputs": [
|
322 |
-
{
|
323 |
-
"data": {
|
324 |
-
"text/plain": [
|
325 |
-
"(8, 64, 257)"
|
326 |
-
]
|
327 |
-
},
|
328 |
-
"execution_count": 15,
|
329 |
-
"metadata": {},
|
330 |
-
"output_type": "execute_result"
|
331 |
-
}
|
332 |
-
],
|
333 |
-
"source": [
|
334 |
-
"superbatch[\"labels\"].shape"
|
335 |
-
]
|
336 |
-
},
|
337 |
-
{
|
338 |
-
"cell_type": "markdown",
|
339 |
-
"id": "6800153b",
|
340 |
-
"metadata": {},
|
341 |
-
"source": [
|
342 |
-
"Any image sequence should begin with `image_bos`:"
|
343 |
-
]
|
344 |
-
},
|
345 |
-
{
|
346 |
-
"cell_type": "code",
|
347 |
-
"execution_count": 16,
|
348 |
-
"id": "cfe23a71",
|
349 |
-
"metadata": {},
|
350 |
-
"outputs": [],
|
351 |
-
"source": [
|
352 |
-
"assert superbatch[\"labels\"][1][5][0].item() == image_bos"
|
353 |
-
]
|
354 |
-
},
|
355 |
-
{
|
356 |
-
"cell_type": "code",
|
357 |
-
"execution_count": null,
|
358 |
-
"id": "0fb899b4",
|
359 |
-
"metadata": {},
|
360 |
-
"outputs": [],
|
361 |
-
"source": []
|
362 |
-
}
|
363 |
-
],
|
364 |
-
"metadata": {
|
365 |
-
"kernelspec": {
|
366 |
-
"display_name": "Python 3 (ipykernel)",
|
367 |
-
"language": "python",
|
368 |
-
"name": "python3"
|
369 |
-
},
|
370 |
-
"language_info": {
|
371 |
-
"codemirror_mode": {
|
372 |
-
"name": "ipython",
|
373 |
-
"version": 3
|
374 |
-
},
|
375 |
-
"file_extension": ".py",
|
376 |
-
"mimetype": "text/x-python",
|
377 |
-
"name": "python",
|
378 |
-
"nbconvert_exporter": "python",
|
379 |
-
"pygments_lexer": "ipython3",
|
380 |
-
"version": "3.8.10"
|
381 |
-
}
|
382 |
-
},
|
383 |
-
"nbformat": 4,
|
384 |
-
"nbformat_minor": 5
|
385 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pyproject.toml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
[tool.isort]
|
2 |
+
profile = "black"
|
requirements.txt
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
# Note: install with the following command:
|
2 |
-
# pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
3 |
-
# Otherwise it won't find the appropriate libtpu_nightly
|
4 |
-
requests
|
5 |
-
jax[tpu]>=0.2.16
|
6 |
-
-e git+https://github.com/huggingface/transformers.git@master#egg=transformers
|
7 |
-
-e git+https://github.com/huggingface/datasets.git@master#egg=datasets
|
8 |
-
flax
|
9 |
-
jupyter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
seq2seq/do_big_run.sh
DELETED
@@ -1,16 +0,0 @@
|
|
1 |
-
python run_seq2seq_flax.py \
|
2 |
-
--max_source_length 128 \
|
3 |
-
--train_file /data/CC12M/encoded-small-train.tsv \ # ignored for now in our script
|
4 |
-
--validation_file /data/CC12M/encoded-small-valid.tsv \ # ignored for now in our script
|
5 |
-
--output_dir output \
|
6 |
-
--per_device_train_batch_size 56 \
|
7 |
-
--per_device_eval_batch_size 56 \
|
8 |
-
--preprocessing_num_workers 80 \
|
9 |
-
--warmup_steps 125 \
|
10 |
-
--gradient_accumulation_steps 8 \
|
11 |
-
--do_train \
|
12 |
-
--do_eval \
|
13 |
-
--adafactor \
|
14 |
-
--num_train_epochs 10 \
|
15 |
-
--log_model \
|
16 |
-
--learning_rate 0.001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
seq2seq/do_small_run.sh
DELETED
@@ -1,16 +0,0 @@
|
|
1 |
-
python run_seq2seq_flax.py \
|
2 |
-
--max_source_length 128 \
|
3 |
-
--train_file /data/CC12M/encoded-small-train.tsv \ # ignored for now in our script
|
4 |
-
--validation_file /data/CC12M/encoded-small-valid.tsv \ # ignored for now in our script
|
5 |
-
--output_dir output \
|
6 |
-
--per_device_train_batch_size 56 \
|
7 |
-
--per_device_eval_batch_size 56 \
|
8 |
-
--preprocessing_num_workers 80 \
|
9 |
-
--warmup_steps 125 \
|
10 |
-
--gradient_accumulation_steps 8 \
|
11 |
-
--do_train \
|
12 |
-
--do_eval \
|
13 |
-
--adafactor \
|
14 |
-
--num_train_epochs 1 \
|
15 |
-
--max_train_samples 20000 \
|
16 |
-
--learning_rate 0.003
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
seq2seq/requirements.txt
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
datasets >= 1.1.3
|
2 |
-
jax>=0.2.8
|
3 |
-
jaxlib>=0.1.59
|
4 |
-
flax>=0.3.4
|
5 |
-
optax>=0.0.8
|
6 |
-
tensorboard
|
7 |
-
nltk
|
8 |
-
wandb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
seq2seq/run_seq2seq_flax.py
DELETED
@@ -1,897 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python
|
2 |
-
# coding=utf-8
|
3 |
-
# Copyright 2021 The HuggingFace Team All rights reserved.
|
4 |
-
#
|
5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
-
# you may not use this file except in compliance with the License.
|
7 |
-
# You may obtain a copy of the License at
|
8 |
-
#
|
9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
-
#
|
11 |
-
# Unless required by applicable law or agreed to in writing, software
|
12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
-
# See the License for the specific language governing permissions and
|
15 |
-
# limitations under the License.
|
16 |
-
"""
|
17 |
-
Fine-tuning the library models for seq2seq, text to image.
|
18 |
-
Script adapted from run_summarization_flax.py
|
19 |
-
"""
|
20 |
-
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
21 |
-
|
22 |
-
import os
|
23 |
-
# set a common huggingface cache folder (used with datasets and transformers) and wandb cache folder (used with artifacts)
|
24 |
-
os.environ['HF_HOME'] = '/data/huggingface/' # required before importing transformers & datasets
|
25 |
-
os.environ['WANDB_CACHE_DIR'] = '/data/wandb/' # required before importing wandb
|
26 |
-
|
27 |
-
import logging as pylogging # To avoid collision with transformers.utils.logging
|
28 |
-
import sys
|
29 |
-
import time
|
30 |
-
from dataclasses import dataclass, field
|
31 |
-
from functools import partial
|
32 |
-
from pathlib import Path
|
33 |
-
from typing import Callable, Optional
|
34 |
-
|
35 |
-
import datasets
|
36 |
-
import nltk # Here to have a nice missing dependency error message early on
|
37 |
-
import numpy as np
|
38 |
-
from datasets import Dataset, load_dataset, load_metric
|
39 |
-
from tqdm import tqdm
|
40 |
-
|
41 |
-
import jax
|
42 |
-
import jax.numpy as jnp
|
43 |
-
import optax
|
44 |
-
import transformers
|
45 |
-
from filelock import FileLock
|
46 |
-
from flax import jax_utils, traverse_util
|
47 |
-
import flax.linen as nn
|
48 |
-
from flax.jax_utils import unreplicate
|
49 |
-
from flax.training import train_state
|
50 |
-
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
51 |
-
from transformers import (
|
52 |
-
CONFIG_MAPPING,
|
53 |
-
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
54 |
-
AutoConfig,
|
55 |
-
AutoTokenizer,
|
56 |
-
FlaxAutoModelForSeq2SeqLM,
|
57 |
-
FlaxBartForConditionalGeneration,
|
58 |
-
HfArgumentParser,
|
59 |
-
TrainingArguments,
|
60 |
-
)
|
61 |
-
from transformers.models.bart.modeling_flax_bart import *
|
62 |
-
from transformers.file_utils import is_offline_mode
|
63 |
-
|
64 |
-
import wandb
|
65 |
-
|
66 |
-
logger = pylogging.getLogger(__name__)
|
67 |
-
|
68 |
-
try:
|
69 |
-
nltk.data.find("tokenizers/punkt")
|
70 |
-
except (LookupError, OSError):
|
71 |
-
if is_offline_mode():
|
72 |
-
raise LookupError(
|
73 |
-
"Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
|
74 |
-
)
|
75 |
-
with FileLock(".lock") as lock:
|
76 |
-
nltk.download("punkt", quiet=True)
|
77 |
-
|
78 |
-
|
79 |
-
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
|
80 |
-
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
81 |
-
|
82 |
-
|
83 |
-
# Model hyperparameters, for convenience
|
84 |
-
OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
|
85 |
-
OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
|
86 |
-
BOS_TOKEN_ID = 16384
|
87 |
-
BASE_MODEL = 'facebook/bart-large-cnn' # we currently have issues with bart-large
|
88 |
-
|
89 |
-
|
90 |
-
@dataclass
|
91 |
-
class ModelArguments:
|
92 |
-
"""
|
93 |
-
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
94 |
-
"""
|
95 |
-
|
96 |
-
model_name_or_path: Optional[str] = field(
|
97 |
-
default=BASE_MODEL,
|
98 |
-
metadata={
|
99 |
-
"help": "The model checkpoint for weights initialization."
|
100 |
-
"Don't set if you want to train a model from scratch."
|
101 |
-
},
|
102 |
-
)
|
103 |
-
model_type: Optional[str] = field(
|
104 |
-
default=None,
|
105 |
-
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
106 |
-
)
|
107 |
-
config_name: Optional[str] = field(
|
108 |
-
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
109 |
-
)
|
110 |
-
tokenizer_name: Optional[str] = field(
|
111 |
-
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
112 |
-
)
|
113 |
-
cache_dir: Optional[str] = field(
|
114 |
-
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
115 |
-
)
|
116 |
-
use_fast_tokenizer: bool = field(
|
117 |
-
default=True,
|
118 |
-
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
119 |
-
)
|
120 |
-
dtype: Optional[str] = field(
|
121 |
-
default="float32",
|
122 |
-
metadata={
|
123 |
-
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
124 |
-
},
|
125 |
-
)
|
126 |
-
|
127 |
-
|
128 |
-
@dataclass
|
129 |
-
class DataTrainingArguments:
|
130 |
-
"""
|
131 |
-
Arguments pertaining to what data we are going to input our model for training and eval.
|
132 |
-
"""
|
133 |
-
|
134 |
-
dataset_name: Optional[str] = field(
|
135 |
-
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
136 |
-
)
|
137 |
-
dataset_config_name: Optional[str] = field(
|
138 |
-
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
139 |
-
)
|
140 |
-
text_column: Optional[str] = field(
|
141 |
-
default='caption',
|
142 |
-
metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
|
143 |
-
)
|
144 |
-
encoding_column: Optional[str] = field(
|
145 |
-
default='encoding',
|
146 |
-
metadata={"help": "The name of the column in the datasets containing the image encodings."},
|
147 |
-
)
|
148 |
-
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
149 |
-
validation_file: Optional[str] = field(
|
150 |
-
default=None,
|
151 |
-
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
152 |
-
)
|
153 |
-
test_file: Optional[str] = field(
|
154 |
-
default=None,
|
155 |
-
metadata={"help": "An optional input predict data file to do prediction on (a text file)."},
|
156 |
-
)
|
157 |
-
max_source_length: Optional[int] = field(
|
158 |
-
default=128,
|
159 |
-
metadata={
|
160 |
-
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
161 |
-
"than this will be truncated, sequences shorter will be padded."
|
162 |
-
},
|
163 |
-
)
|
164 |
-
no_decay: bool = field(
|
165 |
-
default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
|
166 |
-
)
|
167 |
-
max_target_length: Optional[int] = field(
|
168 |
-
default=OUTPUT_LENGTH,
|
169 |
-
metadata={
|
170 |
-
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
171 |
-
"than this will be truncated, sequences shorter will be padded."
|
172 |
-
},
|
173 |
-
)
|
174 |
-
val_max_target_length: Optional[int] = field(
|
175 |
-
default=OUTPUT_LENGTH,
|
176 |
-
metadata={
|
177 |
-
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
178 |
-
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
179 |
-
"This argument is also used to override the `max_length` param of `model.generate`, which is used "
|
180 |
-
"during evaluation."
|
181 |
-
},
|
182 |
-
)
|
183 |
-
max_train_samples: Optional[int] = field(
|
184 |
-
default=None,
|
185 |
-
metadata={
|
186 |
-
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
187 |
-
"value if set."
|
188 |
-
},
|
189 |
-
)
|
190 |
-
max_eval_samples: Optional[int] = field(
|
191 |
-
default=None,
|
192 |
-
metadata={
|
193 |
-
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
194 |
-
"value if set."
|
195 |
-
},
|
196 |
-
)
|
197 |
-
max_predict_samples: Optional[int] = field(
|
198 |
-
default=None,
|
199 |
-
metadata={
|
200 |
-
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
201 |
-
"value if set."
|
202 |
-
},
|
203 |
-
)
|
204 |
-
preprocessing_num_workers: Optional[int] = field(
|
205 |
-
default=80, # ensure we have the same datasets cached data and avoid using too much space
|
206 |
-
metadata={"help": "The number of processes to use for the preprocessing."},
|
207 |
-
)
|
208 |
-
source_prefix: Optional[str] = field(
|
209 |
-
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
210 |
-
)
|
211 |
-
predict_with_generate: bool = field(
|
212 |
-
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
|
213 |
-
)
|
214 |
-
num_beams: Optional[int] = field(
|
215 |
-
default=None,
|
216 |
-
metadata={
|
217 |
-
"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
|
218 |
-
"which is used during evaluation."
|
219 |
-
},
|
220 |
-
)
|
221 |
-
overwrite_cache: bool = field(
|
222 |
-
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
223 |
-
)
|
224 |
-
log_interval: Optional[int] = field(
|
225 |
-
default=40,
|
226 |
-
metadata={
|
227 |
-
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
228 |
-
"value if set."
|
229 |
-
},
|
230 |
-
)
|
231 |
-
log_model: bool = field(
|
232 |
-
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
233 |
-
)
|
234 |
-
save_model_steps: Optional[int] = field(
|
235 |
-
default=3000, # about once every hour in our experiments
|
236 |
-
metadata={
|
237 |
-
"help": "For logging the model more frequently. Used only when `log_model` is set."
|
238 |
-
},
|
239 |
-
)
|
240 |
-
|
241 |
-
def __post_init__(self):
|
242 |
-
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
243 |
-
raise ValueError("Need either a dataset name or a training/validation file.")
|
244 |
-
else:
|
245 |
-
if self.train_file is not None:
|
246 |
-
extension = self.train_file.split(".")[-1]
|
247 |
-
assert extension in ["tsv", "csv", "json"], "`train_file` should be a tsv, csv or json file."
|
248 |
-
if self.validation_file is not None:
|
249 |
-
extension = self.validation_file.split(".")[-1]
|
250 |
-
assert extension in ["tsv", "csv", "json"], "`validation_file` should be a tsv, csv or json file."
|
251 |
-
if self.val_max_target_length is None:
|
252 |
-
self.val_max_target_length = self.max_target_length
|
253 |
-
|
254 |
-
|
255 |
-
class TrainState(train_state.TrainState):
|
256 |
-
dropout_rng: jnp.ndarray
|
257 |
-
grad_accum: jnp.ndarray
|
258 |
-
optimizer_step: int
|
259 |
-
|
260 |
-
def replicate(self):
|
261 |
-
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
262 |
-
|
263 |
-
|
264 |
-
class CustomFlaxBartModule(FlaxBartModule):
|
265 |
-
def setup(self):
|
266 |
-
# we keep shared to easily load pre-trained weights
|
267 |
-
self.shared = nn.Embed(
|
268 |
-
self.config.vocab_size,
|
269 |
-
self.config.d_model,
|
270 |
-
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
271 |
-
dtype=self.dtype,
|
272 |
-
)
|
273 |
-
# a separate embedding is used for the decoder
|
274 |
-
self.decoder_embed = nn.Embed(
|
275 |
-
OUTPUT_VOCAB_SIZE,
|
276 |
-
self.config.d_model,
|
277 |
-
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
278 |
-
dtype=self.dtype,
|
279 |
-
)
|
280 |
-
self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
|
281 |
-
|
282 |
-
# the decoder has a different config
|
283 |
-
decoder_config = BartConfig(self.config.to_dict())
|
284 |
-
decoder_config.max_position_embeddings = OUTPUT_LENGTH
|
285 |
-
decoder_config.min_length = OUTPUT_LENGTH
|
286 |
-
decoder_config.max_length = OUTPUT_LENGTH
|
287 |
-
decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
|
288 |
-
self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
|
289 |
-
|
290 |
-
class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
|
291 |
-
def setup(self):
|
292 |
-
self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
|
293 |
-
self.lm_head = nn.Dense(
|
294 |
-
OUTPUT_VOCAB_SIZE,
|
295 |
-
use_bias=False,
|
296 |
-
dtype=self.dtype,
|
297 |
-
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
298 |
-
)
|
299 |
-
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
|
300 |
-
|
301 |
-
class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
|
302 |
-
module_class = CustomFlaxBartForConditionalGenerationModule
|
303 |
-
|
304 |
-
|
305 |
-
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
|
306 |
-
"""
|
307 |
-
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
308 |
-
Shuffle batches if `shuffle` is `True`.
|
309 |
-
"""
|
310 |
-
steps_per_epoch = len(dataset) // batch_size
|
311 |
-
|
312 |
-
if shuffle:
|
313 |
-
batch_idx = jax.random.permutation(rng, len(dataset))
|
314 |
-
else:
|
315 |
-
batch_idx = jnp.arange(len(dataset))
|
316 |
-
|
317 |
-
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
318 |
-
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
319 |
-
|
320 |
-
for idx in batch_idx:
|
321 |
-
batch = dataset[idx]
|
322 |
-
batch = {k: jnp.array(v) for k, v in batch.items()}
|
323 |
-
|
324 |
-
batch = shard(batch)
|
325 |
-
|
326 |
-
yield batch
|
327 |
-
|
328 |
-
|
329 |
-
def create_learning_rate_fn(
|
330 |
-
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
|
331 |
-
) -> Callable[[int], jnp.array]:
|
332 |
-
"""Returns a linear warmup, linear_decay learning rate function."""
|
333 |
-
steps_per_epoch = train_ds_size // train_batch_size
|
334 |
-
num_train_steps = steps_per_epoch * num_train_epochs
|
335 |
-
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
|
336 |
-
if no_decay:
|
337 |
-
return warmup_fn
|
338 |
-
decay_fn = optax.linear_schedule(
|
339 |
-
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
|
340 |
-
)
|
341 |
-
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
|
342 |
-
return schedule_fn
|
343 |
-
|
344 |
-
|
345 |
-
def wandb_log(metrics, step=None, prefix=None):
|
346 |
-
if jax.process_index() == 0:
|
347 |
-
log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
|
348 |
-
if step is not None:
|
349 |
-
log_metrics['train/step'] = step
|
350 |
-
wandb.log(log_metrics)
|
351 |
-
|
352 |
-
|
353 |
-
def main():
|
354 |
-
# See all possible arguments in src/transformers/training_args.py
|
355 |
-
# or by passing the --help flag to this script.
|
356 |
-
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
357 |
-
|
358 |
-
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
359 |
-
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
360 |
-
# If we pass only one argument to the script and it's the path to a json file,
|
361 |
-
# let's parse it to get our arguments.
|
362 |
-
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
363 |
-
else:
|
364 |
-
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
365 |
-
|
366 |
-
logger.warning(f"eval_steps has been manually hardcoded") # TODO: remove it later, convenient for now
|
367 |
-
training_args.eval_steps = 400
|
368 |
-
|
369 |
-
if (
|
370 |
-
os.path.exists(training_args.output_dir)
|
371 |
-
and os.listdir(training_args.output_dir)
|
372 |
-
and training_args.do_train
|
373 |
-
and not training_args.overwrite_output_dir
|
374 |
-
):
|
375 |
-
raise ValueError(
|
376 |
-
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
377 |
-
"Use --overwrite_output_dir to overcome."
|
378 |
-
)
|
379 |
-
|
380 |
-
# Set up wandb run
|
381 |
-
wandb.init(
|
382 |
-
entity='wandb',
|
383 |
-
project='hf-flax-dalle-mini',
|
384 |
-
job_type='Seq2SeqVQGAN',
|
385 |
-
config=parser.parse_args()
|
386 |
-
)
|
387 |
-
|
388 |
-
# set default x-axis as 'train/step'
|
389 |
-
wandb.define_metric('train/step')
|
390 |
-
wandb.define_metric('*', step_metric='train/step')
|
391 |
-
|
392 |
-
# Make one log on every process with the configuration for debugging.
|
393 |
-
pylogging.basicConfig(
|
394 |
-
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
395 |
-
datefmt="%m/%d/%Y %H:%M:%S",
|
396 |
-
level=pylogging.INFO,
|
397 |
-
)
|
398 |
-
# Setup logging, we only want one process per machine to log things on the screen.
|
399 |
-
logger.setLevel(pylogging.INFO if jax.process_index() == 0 else pylogging.ERROR)
|
400 |
-
if jax.process_index() == 0:
|
401 |
-
datasets.utils.logging.set_verbosity_warning()
|
402 |
-
transformers.utils.logging.set_verbosity_info()
|
403 |
-
else:
|
404 |
-
datasets.utils.logging.set_verbosity_error()
|
405 |
-
transformers.utils.logging.set_verbosity_error()
|
406 |
-
|
407 |
-
# Set the verbosity to info of the Transformers logger (on main process only):
|
408 |
-
logger.info(f"Training/evaluation parameters {training_args}")
|
409 |
-
|
410 |
-
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
|
411 |
-
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
412 |
-
# (the dataset will be downloaded automatically from the datasets Hub).
|
413 |
-
#
|
414 |
-
data_files = {}
|
415 |
-
logger.warning(f"Datasets path have been manually hardcoded") # TODO: remove it later, convenient for now
|
416 |
-
if data_args.train_file is not None:
|
417 |
-
data_files["train"] = ["/data/CC3M/training-encoded.tsv", "/data/CC12M/encoded-train.tsv"]
|
418 |
-
if data_args.validation_file is not None:
|
419 |
-
data_files["validation"] = ["/data/CC3M/validation-encoded.tsv"]
|
420 |
-
if data_args.test_file is not None:
|
421 |
-
data_files["test"] = data_args.test_file
|
422 |
-
dataset = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")
|
423 |
-
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
424 |
-
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
425 |
-
|
426 |
-
# Load pretrained model and tokenizer
|
427 |
-
base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
|
428 |
-
model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
429 |
-
)
|
430 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
431 |
-
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
432 |
-
)
|
433 |
-
|
434 |
-
# Set up our new model config
|
435 |
-
config = BartConfig.from_pretrained(model_args.model_name_or_path)
|
436 |
-
config.tie_word_embeddings = False
|
437 |
-
config.decoder_start_token_id = BOS_TOKEN_ID
|
438 |
-
config.bos_token_id = BOS_TOKEN_ID # should not be used
|
439 |
-
config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
|
440 |
-
config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
|
441 |
-
config.forced_bos_token_id = None # we don't need this token
|
442 |
-
config.forced_eos_token_id = None # we don't need this token
|
443 |
-
#config.min_length = data_args.max_target_length # Set only in decoder?
|
444 |
-
#config.max_length = data_args.max_target_length # Set only in decoder?
|
445 |
-
|
446 |
-
print(f"TPUs: {jax.device_count()}")
|
447 |
-
assert jax.device_count() == 8, "TPUs in use, please check running processes"
|
448 |
-
|
449 |
-
# Create a custom model and initialize it randomly
|
450 |
-
model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
451 |
-
|
452 |
-
# Use pre-trained weights for encoder
|
453 |
-
model.params['model']['encoder'] = base_model.params['model']['encoder']
|
454 |
-
model.params['model']['shared'] = base_model.params['model']['shared']
|
455 |
-
del base_model
|
456 |
-
|
457 |
-
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
458 |
-
|
459 |
-
# Preprocessing the datasets.
|
460 |
-
# We need to tokenize inputs and targets.
|
461 |
-
if training_args.do_train:
|
462 |
-
column_names = dataset["train"].column_names
|
463 |
-
elif training_args.do_eval:
|
464 |
-
column_names = dataset["validation"].column_names
|
465 |
-
elif training_args.do_predict:
|
466 |
-
column_names = dataset["test"].column_names
|
467 |
-
else:
|
468 |
-
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
469 |
-
return
|
470 |
-
|
471 |
-
# Get the column names for input/target.
|
472 |
-
text_column = data_args.text_column
|
473 |
-
encoding_column = data_args.encoding_column
|
474 |
-
|
475 |
-
# Temporarily set max_target_length for training.
|
476 |
-
max_target_length = data_args.max_target_length
|
477 |
-
|
478 |
-
def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
|
479 |
-
"""
|
480 |
-
Shift input ids one token to the right.
|
481 |
-
"""
|
482 |
-
shifted_input_ids = np.zeros(input_ids.shape)
|
483 |
-
shifted_input_ids[:, 1:] = input_ids[:, :-1]
|
484 |
-
shifted_input_ids[:, 0] = decoder_start_token_id
|
485 |
-
return shifted_input_ids
|
486 |
-
|
487 |
-
def preprocess_function(examples):
|
488 |
-
inputs = examples[text_column]
|
489 |
-
inputs = [prefix + inp for inp in inputs]
|
490 |
-
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
491 |
-
model_inputs = tokenizer(
|
492 |
-
inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
|
493 |
-
)
|
494 |
-
|
495 |
-
# set up targets
|
496 |
-
# Note: labels correspond to our target indices
|
497 |
-
# decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
|
498 |
-
labels = [eval(indices) for indices in examples['encoding']]
|
499 |
-
labels = np.asarray(labels)
|
500 |
-
|
501 |
-
# We need the labels, in addition to the decoder_input_ids, for the compute_loss function
|
502 |
-
model_inputs["labels"] = labels
|
503 |
-
|
504 |
-
# In our case, this prepends the bos token and removes the last one
|
505 |
-
decoder_input_ids = shift_tokens_right(labels, config.decoder_start_token_id)
|
506 |
-
model_inputs["decoder_input_ids"] = decoder_input_ids
|
507 |
-
|
508 |
-
return model_inputs
|
509 |
-
|
510 |
-
if training_args.do_train:
|
511 |
-
if "train" not in dataset:
|
512 |
-
raise ValueError("--do_train requires a train dataset")
|
513 |
-
train_dataset = dataset["train"]
|
514 |
-
if data_args.max_train_samples is not None:
|
515 |
-
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
516 |
-
train_dataset = train_dataset.map(
|
517 |
-
preprocess_function,
|
518 |
-
batched=True,
|
519 |
-
num_proc=data_args.preprocessing_num_workers,
|
520 |
-
remove_columns=column_names,
|
521 |
-
load_from_cache_file=not data_args.overwrite_cache,
|
522 |
-
desc="Running tokenizer on train dataset",
|
523 |
-
)
|
524 |
-
|
525 |
-
if training_args.do_eval:
|
526 |
-
max_target_length = data_args.val_max_target_length
|
527 |
-
if "validation" not in dataset:
|
528 |
-
raise ValueError("--do_eval requires a validation dataset")
|
529 |
-
eval_dataset = dataset["validation"]
|
530 |
-
if data_args.max_eval_samples is not None:
|
531 |
-
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
532 |
-
eval_dataset = eval_dataset.map(
|
533 |
-
preprocess_function,
|
534 |
-
batched=True,
|
535 |
-
num_proc=data_args.preprocessing_num_workers,
|
536 |
-
remove_columns=column_names,
|
537 |
-
load_from_cache_file=not data_args.overwrite_cache,
|
538 |
-
desc="Running tokenizer on validation dataset",
|
539 |
-
)
|
540 |
-
|
541 |
-
if training_args.do_predict:
|
542 |
-
max_target_length = data_args.val_max_target_length
|
543 |
-
if "test" not in dataset:
|
544 |
-
raise ValueError("--do_predict requires a test dataset")
|
545 |
-
predict_dataset = dataset["test"]
|
546 |
-
if data_args.max_predict_samples is not None:
|
547 |
-
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
548 |
-
predict_dataset = predict_dataset.map(
|
549 |
-
preprocess_function,
|
550 |
-
batched=True,
|
551 |
-
num_proc=data_args.preprocessing_num_workers,
|
552 |
-
remove_columns=column_names,
|
553 |
-
load_from_cache_file=not data_args.overwrite_cache,
|
554 |
-
desc="Running tokenizer on prediction dataset",
|
555 |
-
)
|
556 |
-
|
557 |
-
# Metric
|
558 |
-
#metric = load_metric("rouge")
|
559 |
-
|
560 |
-
def postprocess_text(preds, labels):
|
561 |
-
preds = [pred.strip() for pred in preds]
|
562 |
-
labels = [label.strip() for label in labels]
|
563 |
-
|
564 |
-
# rougeLSum expects newline after each sentence
|
565 |
-
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
|
566 |
-
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
|
567 |
-
|
568 |
-
return preds, labels
|
569 |
-
|
570 |
-
def compute_metrics(preds, labels):
|
571 |
-
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
572 |
-
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
573 |
-
|
574 |
-
# Some simple post-processing
|
575 |
-
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
|
576 |
-
|
577 |
-
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
|
578 |
-
# Extract a few results from ROUGE
|
579 |
-
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
|
580 |
-
|
581 |
-
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
|
582 |
-
result["gen_len"] = np.mean(prediction_lens)
|
583 |
-
result = {k: round(v, 4) for k, v in result.items()}
|
584 |
-
return result
|
585 |
-
|
586 |
-
# Initialize our training
|
587 |
-
rng = jax.random.PRNGKey(training_args.seed)
|
588 |
-
rng, dropout_rng = jax.random.split(rng)
|
589 |
-
|
590 |
-
# Store some constant
|
591 |
-
num_epochs = int(training_args.num_train_epochs)
|
592 |
-
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
593 |
-
total_batch_size = int(train_batch_size) * training_args.gradient_accumulation_steps
|
594 |
-
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
595 |
-
steps_per_epoch = len(train_dataset) // train_batch_size
|
596 |
-
total_steps = steps_per_epoch * num_epochs
|
597 |
-
total_optimization_steps = (len(train_dataset) // total_batch_size) * num_epochs
|
598 |
-
|
599 |
-
# Create learning rate schedule
|
600 |
-
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
601 |
-
len(train_dataset),
|
602 |
-
total_batch_size,
|
603 |
-
training_args.num_train_epochs,
|
604 |
-
training_args.warmup_steps,
|
605 |
-
training_args.learning_rate,
|
606 |
-
data_args.no_decay
|
607 |
-
)
|
608 |
-
|
609 |
-
# We use Optax's "masking" functionality to not apply weight decay
|
610 |
-
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
611 |
-
# mask boolean with the same structure as the parameters.
|
612 |
-
# The mask is True for parameters that should be decayed.
|
613 |
-
# Note that this mask is specifically adapted for FlaxBart.
|
614 |
-
# For FlaxT5, one should correct the layer norm parameter naming
|
615 |
-
# accordingly - see `run_t5_mlm_flax.py` e.g.
|
616 |
-
def decay_mask_fn(params):
|
617 |
-
flat_params = traverse_util.flatten_dict(params)
|
618 |
-
layer_norm_params = [
|
619 |
-
(name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
|
620 |
-
]
|
621 |
-
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
|
622 |
-
return traverse_util.unflatten_dict(flat_mask)
|
623 |
-
|
624 |
-
# create adam optimizer
|
625 |
-
if training_args.adafactor:
|
626 |
-
# We use the default parameters here to initialize adafactor,
|
627 |
-
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
|
628 |
-
optimizer = optax.adafactor(
|
629 |
-
learning_rate=linear_decay_lr_schedule_fn,
|
630 |
-
)
|
631 |
-
else:
|
632 |
-
optimizer = optax.adamw(
|
633 |
-
learning_rate=linear_decay_lr_schedule_fn,
|
634 |
-
b1=training_args.adam_beta1,
|
635 |
-
b2=training_args.adam_beta2,
|
636 |
-
eps=training_args.adam_epsilon,
|
637 |
-
weight_decay=training_args.weight_decay,
|
638 |
-
mask=decay_mask_fn,
|
639 |
-
)
|
640 |
-
|
641 |
-
# Setup train state
|
642 |
-
state = TrainState.create(
|
643 |
-
apply_fn=model.__call__,
|
644 |
-
params=model.params,
|
645 |
-
tx=optimizer,
|
646 |
-
dropout_rng=dropout_rng,
|
647 |
-
grad_accum=jax.tree_map(jnp.zeros_like, model.params),
|
648 |
-
optimizer_step=0,
|
649 |
-
)
|
650 |
-
|
651 |
-
# label smoothed cross entropy
|
652 |
-
def loss_fn(logits, labels):
|
653 |
-
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
|
654 |
-
loss = loss.mean()
|
655 |
-
return loss
|
656 |
-
|
657 |
-
# Define gradient update step fn
|
658 |
-
def train_step(state, batch):
|
659 |
-
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
660 |
-
|
661 |
-
def compute_loss(params):
|
662 |
-
labels = batch.pop("labels")
|
663 |
-
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
664 |
-
loss = loss_fn(logits, labels)
|
665 |
-
return loss
|
666 |
-
|
667 |
-
grad_fn = jax.value_and_grad(compute_loss)
|
668 |
-
loss, grads = grad_fn(state.params)
|
669 |
-
grad_accum = jax.tree_multimap(lambda x, y: x + y, grads, state.grad_accum)
|
670 |
-
|
671 |
-
def update_fn():
|
672 |
-
grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
|
673 |
-
grads = jax.lax.pmean(grads, "batch")
|
674 |
-
new_state = state.apply_gradients(
|
675 |
-
grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step + 1
|
676 |
-
)
|
677 |
-
return new_state
|
678 |
-
|
679 |
-
new_state = jax.lax.cond(
|
680 |
-
(state.step + 1) % training_args.gradient_accumulation_steps == 0,
|
681 |
-
lambda _: update_fn(),
|
682 |
-
lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
|
683 |
-
None,
|
684 |
-
)
|
685 |
-
|
686 |
-
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.optimizer_step)}
|
687 |
-
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
688 |
-
|
689 |
-
return new_state.replace(dropout_rng=new_dropout_rng), metrics
|
690 |
-
|
691 |
-
# Define eval fn
|
692 |
-
def eval_step(params, batch):
|
693 |
-
labels = batch.pop("labels")
|
694 |
-
logits = model(**batch, params=params, train=False)[0]
|
695 |
-
loss = loss_fn(logits, labels)
|
696 |
-
|
697 |
-
# summarize metrics
|
698 |
-
metrics = {"loss": loss}
|
699 |
-
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
700 |
-
return metrics
|
701 |
-
|
702 |
-
# Define generation function
|
703 |
-
max_length = (
|
704 |
-
data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
|
705 |
-
)
|
706 |
-
num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
|
707 |
-
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
708 |
-
|
709 |
-
def generate_step(params, batch):
|
710 |
-
model.params = params
|
711 |
-
output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
|
712 |
-
return output_ids.sequences
|
713 |
-
|
714 |
-
# Create parallel version of the train and eval step
|
715 |
-
p_train_step = jax.pmap(
|
716 |
-
train_step, "batch", donate_argnums=(0,)
|
717 |
-
)
|
718 |
-
p_eval_step = jax.pmap(eval_step, "batch")
|
719 |
-
p_generate_step = jax.pmap(generate_step, "batch")
|
720 |
-
|
721 |
-
# Replicate the train state on each device
|
722 |
-
state = state.replicate()
|
723 |
-
|
724 |
-
logger.info("***** Running training *****")
|
725 |
-
logger.info(f" Num examples = {len(train_dataset)}")
|
726 |
-
logger.info(f" Num Epochs = {num_epochs}")
|
727 |
-
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
|
728 |
-
logger.info(
|
729 |
-
f" Total train batch size (w. parallel & distributed) = {train_batch_size * training_args.gradient_accumulation_steps}"
|
730 |
-
)
|
731 |
-
logger.info(f" Total global steps = {total_steps}")
|
732 |
-
logger.info(f" Total optimization steps = {total_optimization_steps}")
|
733 |
-
|
734 |
-
train_time = 0
|
735 |
-
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
736 |
-
global_step = 0
|
737 |
-
|
738 |
-
def run_evaluation():
|
739 |
-
# ======================== Evaluating ==============================
|
740 |
-
eval_metrics = []
|
741 |
-
if training_args.do_eval:
|
742 |
-
eval_preds = []
|
743 |
-
eval_labels = []
|
744 |
-
|
745 |
-
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
|
746 |
-
eval_steps = len(eval_dataset) // eval_batch_size
|
747 |
-
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
|
748 |
-
# Model forward
|
749 |
-
batch = next(eval_loader)
|
750 |
-
labels = batch["labels"]
|
751 |
-
|
752 |
-
metrics = p_eval_step(state.params, batch)
|
753 |
-
eval_metrics.append(metrics)
|
754 |
-
|
755 |
-
# generation
|
756 |
-
if data_args.predict_with_generate:
|
757 |
-
generated_ids = p_generate_step(state.params, batch)
|
758 |
-
eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
759 |
-
eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
|
760 |
-
|
761 |
-
# normalize eval metrics
|
762 |
-
eval_metrics = get_metrics(eval_metrics)
|
763 |
-
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
764 |
-
|
765 |
-
# log metrics
|
766 |
-
wandb_log(eval_metrics, step=global_step, prefix='eval')
|
767 |
-
|
768 |
-
# compute ROUGE metrics
|
769 |
-
rouge_desc = ""
|
770 |
-
# if data_args.predict_with_generate:
|
771 |
-
# rouge_metrics = compute_metrics(eval_preds, eval_labels)
|
772 |
-
# eval_metrics.update(rouge_metrics)
|
773 |
-
# rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
|
774 |
-
|
775 |
-
# Print metrics and update progress bar
|
776 |
-
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
|
777 |
-
epochs.write(desc)
|
778 |
-
epochs.desc = desc
|
779 |
-
|
780 |
-
return eval_metrics
|
781 |
-
|
782 |
-
def run_save_model(step, epoch, eval_metrics=None):
|
783 |
-
if jax.process_index() == 0:
|
784 |
-
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
785 |
-
|
786 |
-
# save model locally
|
787 |
-
model.save_pretrained(
|
788 |
-
training_args.output_dir,
|
789 |
-
params=params,
|
790 |
-
)
|
791 |
-
|
792 |
-
# save to W&B
|
793 |
-
if data_args.log_model:
|
794 |
-
metadata = {'step': step, 'epoch': epoch}
|
795 |
-
if eval_metrics is not None:
|
796 |
-
metadata['eval/loss'] = eval_metrics['loss']
|
797 |
-
artifact = wandb.Artifact(
|
798 |
-
name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
|
799 |
-
)
|
800 |
-
artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
|
801 |
-
artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
|
802 |
-
wandb.run.log_artifact(artifact)
|
803 |
-
|
804 |
-
# save to the hub
|
805 |
-
if training_args.push_to_hub:
|
806 |
-
model.save_pretrained(
|
807 |
-
training_args.output_dir,
|
808 |
-
params=params,
|
809 |
-
push_to_hub=training_args.push_to_hub,
|
810 |
-
commit_message=f"Saving weights and logs of epoch {epoch+1}",
|
811 |
-
temp_dir=True # avoid issues with being in a repository
|
812 |
-
)
|
813 |
-
|
814 |
-
for epoch in epochs:
|
815 |
-
# ======================== Training ================================
|
816 |
-
train_start = time.time()
|
817 |
-
|
818 |
-
# Create sampling rng
|
819 |
-
rng, input_rng = jax.random.split(rng)
|
820 |
-
|
821 |
-
# Generate an epoch by shuffling sampling indices from the train dataset
|
822 |
-
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
|
823 |
-
steps_per_epoch = len(train_dataset) // train_batch_size
|
824 |
-
# train
|
825 |
-
for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
826 |
-
global_step +=1
|
827 |
-
batch = next(train_loader)
|
828 |
-
state, train_metric = p_train_step(state, batch)
|
829 |
-
|
830 |
-
if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
|
831 |
-
# log metrics
|
832 |
-
wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
|
833 |
-
|
834 |
-
if global_step % training_args.eval_steps == 0:
|
835 |
-
run_evaluation()
|
836 |
-
|
837 |
-
if global_step % data_args.save_model_steps == 0:
|
838 |
-
run_save_model(global_step, epoch)
|
839 |
-
|
840 |
-
# log final train metrics
|
841 |
-
wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
|
842 |
-
|
843 |
-
train_time += time.time() - train_start
|
844 |
-
train_metric = unreplicate(train_metric)
|
845 |
-
epochs.write(
|
846 |
-
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
847 |
-
)
|
848 |
-
|
849 |
-
# Final evaluation
|
850 |
-
eval_metrics = run_evaluation()
|
851 |
-
|
852 |
-
# save checkpoint after each epoch and push checkpoint to the hub
|
853 |
-
run_save_model(global_step, epoch, eval_metrics)
|
854 |
-
|
855 |
-
|
856 |
-
# ======================== Prediction loop ==============================
|
857 |
-
if training_args.do_predict:
|
858 |
-
logger.info("*** Predict ***")
|
859 |
-
|
860 |
-
pred_metrics = []
|
861 |
-
pred_generations = []
|
862 |
-
pred_labels = []
|
863 |
-
|
864 |
-
pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
|
865 |
-
pred_steps = len(predict_dataset) // eval_batch_size
|
866 |
-
for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
|
867 |
-
# Model forward
|
868 |
-
batch = next(pred_loader)
|
869 |
-
labels = batch["labels"]
|
870 |
-
|
871 |
-
metrics = p_eval_step(state.params, batch)
|
872 |
-
pred_metrics.append(metrics)
|
873 |
-
|
874 |
-
# generation
|
875 |
-
if data_args.predict_with_generate:
|
876 |
-
generated_ids = p_generate_step(state.params, batch)
|
877 |
-
pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
878 |
-
pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
|
879 |
-
|
880 |
-
# normalize prediction metrics
|
881 |
-
pred_metrics = get_metrics(pred_metrics)
|
882 |
-
pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
|
883 |
-
|
884 |
-
# compute ROUGE metrics
|
885 |
-
rouge_desc = ""
|
886 |
-
if data_args.predict_with_generate:
|
887 |
-
rouge_metrics = compute_metrics(pred_generations, pred_labels)
|
888 |
-
pred_metrics.update(rouge_metrics)
|
889 |
-
rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
|
890 |
-
|
891 |
-
# Print metrics
|
892 |
-
desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
|
893 |
-
logger.info(desc)
|
894 |
-
|
895 |
-
|
896 |
-
if __name__ == "__main__":
|
897 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
setup.cfg
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[metadata]
|
2 |
+
name = dalle_mini
|
3 |
+
version = attr: dalle_mini.__version__
|
4 |
+
description = DALL·E mini - Generate images from a text prompt
|
5 |
+
long_description = file: README.md
|
6 |
+
long_description_content_type = text/markdown
|
7 |
+
url = https://github.com/borisdayma/dalle-mini
|
8 |
+
project_urls =
|
9 |
+
Bug Tracker = https://github.com/borisdayma/dalle-mini/issues
|
10 |
+
|
11 |
+
[options]
|
12 |
+
packages = find:
|
13 |
+
install_requires =
|
14 |
+
transformers
|
15 |
+
unidecode
|
16 |
+
ftfy
|
17 |
+
pillow
|
18 |
+
jax
|
19 |
+
flax
|
20 |
+
|
21 |
+
[options.extras_require]
|
22 |
+
dev =
|
23 |
+
tqdm
|
24 |
+
wandb
|
25 |
+
optax
|
26 |
+
black[jupyter]
|
27 |
+
isort
|
setup.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
setup()
|
tools/dataset/encode_dataset.ipynb
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "d0b72877",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Pre-encoding a dataset for DALLE·mini"
|
9 |
+
]
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"cell_type": "markdown",
|
13 |
+
"id": "ba7b31e6",
|
14 |
+
"metadata": {},
|
15 |
+
"source": [
|
16 |
+
"This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n",
|
17 |
+
"\n",
|
18 |
+
"Adapt it to your own dataset and image encoder.\n",
|
19 |
+
"\n",
|
20 |
+
"At the end you should have a dataset of pairs:\n",
|
21 |
+
"* a caption defined as a string\n",
|
22 |
+
"* an encoded image defined as a list of int."
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "code",
|
27 |
+
"execution_count": null,
|
28 |
+
"id": "3b59489e",
|
29 |
+
"metadata": {},
|
30 |
+
"outputs": [],
|
31 |
+
"source": [
|
32 |
+
"from tqdm.notebook import tqdm\n",
|
33 |
+
"\n",
|
34 |
+
"import torchvision.transforms as T\n",
|
35 |
+
"\n",
|
36 |
+
"import webdataset as wds\n",
|
37 |
+
"\n",
|
38 |
+
"import jax\n",
|
39 |
+
"import braceexpand\n",
|
40 |
+
"from pathlib import Path"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "markdown",
|
45 |
+
"id": "c7c4c1e6",
|
46 |
+
"metadata": {},
|
47 |
+
"source": [
|
48 |
+
"## Configuration Parameters"
|
49 |
+
]
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"cell_type": "code",
|
53 |
+
"execution_count": 3,
|
54 |
+
"id": "1265dbfe",
|
55 |
+
"metadata": {},
|
56 |
+
"outputs": [],
|
57 |
+
"source": [
|
58 |
+
"shards = \"my_images/shard-{0000..0008}.tar\" # defined using braceexpand format as used by webdataset\n",
|
59 |
+
"encoded_output = Path(\"encoded_data\") # where we will save our encoded data\n",
|
60 |
+
"\n",
|
61 |
+
"VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
|
62 |
+
" \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
|
63 |
+
" \"85eb5d3b51a1c62a0cc8f4ccdee9882c0d0bd384\",\n",
|
64 |
+
")\n",
|
65 |
+
"\n",
|
66 |
+
"# good defaults for a TPU v3-8\n",
|
67 |
+
"batch_size = 128 # Per device\n",
|
68 |
+
"num_workers = 8 # For parallel processing\n",
|
69 |
+
"total_bs = batch_size * jax.device_count() # You can use a smaller size while testing\n",
|
70 |
+
"save_frequency = 128 # Number of batches to create a new file (180MB for f16 and 720MB for f8 per file)"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"cell_type": "code",
|
75 |
+
"execution_count": 5,
|
76 |
+
"id": "cd956ec6-7d98-4d4d-a454-f80fe857eadd",
|
77 |
+
"metadata": {},
|
78 |
+
"outputs": [
|
79 |
+
{
|
80 |
+
"data": {
|
81 |
+
"text/plain": [
|
82 |
+
"['XXX/shard-0000.tar',\n",
|
83 |
+
" 'XXX/shard-0001.tar',\n",
|
84 |
+
" 'XXX/shard-0002.tar',\n",
|
85 |
+
" 'XXX/shard-0003.tar',\n",
|
86 |
+
" 'XXX/shard-0004.tar',\n",
|
87 |
+
" 'XXX/shard-0005.tar',\n",
|
88 |
+
" 'XXX/shard-0006.tar',\n",
|
89 |
+
" 'XXX/shard-0007.tar',\n",
|
90 |
+
" 'XXX/shard-0008.tar']"
|
91 |
+
]
|
92 |
+
},
|
93 |
+
"execution_count": 5,
|
94 |
+
"metadata": {},
|
95 |
+
"output_type": "execute_result"
|
96 |
+
}
|
97 |
+
],
|
98 |
+
"source": [
|
99 |
+
"shards = list(\n",
|
100 |
+
" braceexpand.braceexpand(shards)\n",
|
101 |
+
") # better display for tqdm with known length"
|
102 |
+
]
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"cell_type": "markdown",
|
106 |
+
"id": "75dba8e2",
|
107 |
+
"metadata": {},
|
108 |
+
"source": [
|
109 |
+
"## Load data"
|
110 |
+
]
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"cell_type": "markdown",
|
114 |
+
"id": "a1e8fb95",
|
115 |
+
"metadata": {},
|
116 |
+
"source": [
|
117 |
+
"We load data using `webdataset`."
|
118 |
+
]
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "code",
|
122 |
+
"execution_count": null,
|
123 |
+
"id": "9ef5de9e",
|
124 |
+
"metadata": {},
|
125 |
+
"outputs": [],
|
126 |
+
"source": [
|
127 |
+
"ds = (\n",
|
128 |
+
" wds.WebDataset(shards, handler=wds.warn_and_continue)\n",
|
129 |
+
" .decode(\"rgb\", handler=wds.warn_and_continue)\n",
|
130 |
+
" .to_tuple(\"jpg\", \"txt\") # assumes image is in `jpg` and caption in `txt`\n",
|
131 |
+
" .batched(total_bs) # load in batch per worker (faster)\n",
|
132 |
+
")"
|
133 |
+
]
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"cell_type": "markdown",
|
137 |
+
"id": "90981824",
|
138 |
+
"metadata": {},
|
139 |
+
"source": [
|
140 |
+
"Note:\n",
|
141 |
+
"* you can also shuffle shards and items using `shardshuffle` and `shuffle` if necessary.\n",
|
142 |
+
"* you may need to resize images in your pipeline (with `map_dict` for example), we assume they are already set to 256x256.\n",
|
143 |
+
"* you can also filter out some items using `select`."
|
144 |
+
]
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"cell_type": "markdown",
|
148 |
+
"id": "129c377d",
|
149 |
+
"metadata": {},
|
150 |
+
"source": [
|
151 |
+
"We can now inspect our data."
|
152 |
+
]
|
153 |
+
},
|
154 |
+
{
|
155 |
+
"cell_type": "code",
|
156 |
+
"execution_count": null,
|
157 |
+
"id": "8cac98cb",
|
158 |
+
"metadata": {
|
159 |
+
"scrolled": true
|
160 |
+
},
|
161 |
+
"outputs": [],
|
162 |
+
"source": [
|
163 |
+
"%%time\n",
|
164 |
+
"images, captions = next(iter(ds))"
|
165 |
+
]
|
166 |
+
},
|
167 |
+
{
|
168 |
+
"cell_type": "code",
|
169 |
+
"execution_count": null,
|
170 |
+
"id": "cd268fbf",
|
171 |
+
"metadata": {},
|
172 |
+
"outputs": [],
|
173 |
+
"source": [
|
174 |
+
"images.shape"
|
175 |
+
]
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"cell_type": "code",
|
179 |
+
"execution_count": null,
|
180 |
+
"id": "5acfc4d8",
|
181 |
+
"metadata": {},
|
182 |
+
"outputs": [],
|
183 |
+
"source": [
|
184 |
+
"captions[:10]"
|
185 |
+
]
|
186 |
+
},
|
187 |
+
{
|
188 |
+
"cell_type": "code",
|
189 |
+
"execution_count": null,
|
190 |
+
"id": "c24693c0",
|
191 |
+
"metadata": {},
|
192 |
+
"outputs": [],
|
193 |
+
"source": [
|
194 |
+
"T.ToPILImage()(images[0].permute(2, 0, 1))"
|
195 |
+
]
|
196 |
+
},
|
197 |
+
{
|
198 |
+
"cell_type": "markdown",
|
199 |
+
"id": "3059ffb1",
|
200 |
+
"metadata": {},
|
201 |
+
"source": [
|
202 |
+
"Finally we create our dataloader."
|
203 |
+
]
|
204 |
+
},
|
205 |
+
{
|
206 |
+
"cell_type": "code",
|
207 |
+
"execution_count": null,
|
208 |
+
"id": "c227c551",
|
209 |
+
"metadata": {},
|
210 |
+
"outputs": [],
|
211 |
+
"source": [
|
212 |
+
"dl = (\n",
|
213 |
+
" wds.WebLoader(ds, batch_size=None, num_workers=8).unbatched().batched(total_bs)\n",
|
214 |
+
") # avoid partial batch at the end of each worker"
|
215 |
+
]
|
216 |
+
},
|
217 |
+
{
|
218 |
+
"cell_type": "markdown",
|
219 |
+
"id": "a354472b",
|
220 |
+
"metadata": {},
|
221 |
+
"source": [
|
222 |
+
"## Image encoder\n",
|
223 |
+
"\n",
|
224 |
+
"We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
|
225 |
+
]
|
226 |
+
},
|
227 |
+
{
|
228 |
+
"cell_type": "code",
|
229 |
+
"execution_count": null,
|
230 |
+
"id": "47a8b818",
|
231 |
+
"metadata": {
|
232 |
+
"scrolled": true
|
233 |
+
},
|
234 |
+
"outputs": [],
|
235 |
+
"source": [
|
236 |
+
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
237 |
+
"from flax.jax_utils import replicate\n",
|
238 |
+
"\n",
|
239 |
+
"vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")\n",
|
240 |
+
"vqgan_params = replicate(vqgan.params)"
|
241 |
+
]
|
242 |
+
},
|
243 |
+
{
|
244 |
+
"cell_type": "markdown",
|
245 |
+
"id": "62ad01c3",
|
246 |
+
"metadata": {},
|
247 |
+
"source": [
|
248 |
+
"## Encoding"
|
249 |
+
]
|
250 |
+
},
|
251 |
+
{
|
252 |
+
"cell_type": "markdown",
|
253 |
+
"id": "20357f74",
|
254 |
+
"metadata": {},
|
255 |
+
"source": [
|
256 |
+
"Encoding is really simple using `shard` to automatically distribute batches across devices and `pmap`."
|
257 |
+
]
|
258 |
+
},
|
259 |
+
{
|
260 |
+
"cell_type": "code",
|
261 |
+
"execution_count": null,
|
262 |
+
"id": "322a4619",
|
263 |
+
"metadata": {},
|
264 |
+
"outputs": [],
|
265 |
+
"source": [
|
266 |
+
"from flax.training.common_utils import shard\n",
|
267 |
+
"from functools import partial\n",
|
268 |
+
"\n",
|
269 |
+
"\n",
|
270 |
+
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
271 |
+
"def p_encode(batch, params):\n",
|
272 |
+
" # Not sure if we should `replicate` params, does not seem to have any effect\n",
|
273 |
+
" _, indices = vqgan.encode(batch, params=params)\n",
|
274 |
+
" return indices"
|
275 |
+
]
|
276 |
+
},
|
277 |
+
{
|
278 |
+
"cell_type": "code",
|
279 |
+
"execution_count": null,
|
280 |
+
"id": "ff6c10d4",
|
281 |
+
"metadata": {},
|
282 |
+
"outputs": [],
|
283 |
+
"source": [
|
284 |
+
"import pandas as pd\n",
|
285 |
+
"\n",
|
286 |
+
"\n",
|
287 |
+
"def encode_dataset(dataloader, output_dir, save_frequency):\n",
|
288 |
+
" output_dir.mkdir(parents=True, exist_ok=True)\n",
|
289 |
+
" all_captions = []\n",
|
290 |
+
" all_encoding = []\n",
|
291 |
+
" n_file = 1\n",
|
292 |
+
" for idx, (images, captions) in enumerate(tqdm(dataloader)):\n",
|
293 |
+
" images = images.numpy()\n",
|
294 |
+
" n = len(images) // 8 * 8\n",
|
295 |
+
" if n != len(images):\n",
|
296 |
+
" # get the max number of images we can (multiple of 8)\n",
|
297 |
+
" print(f\"Different sizes {n} vs {len(images)}\")\n",
|
298 |
+
" images = images[:n]\n",
|
299 |
+
" captions = captions[:n]\n",
|
300 |
+
" if not len(captions):\n",
|
301 |
+
" print(f\"No images/captions in batch...\")\n",
|
302 |
+
" continue\n",
|
303 |
+
" images = shard(images)\n",
|
304 |
+
" encoded = p_encode(images, vqgan_params)\n",
|
305 |
+
" encoded = encoded.reshape(-1, encoded.shape[-1])\n",
|
306 |
+
" all_captions.extend(captions)\n",
|
307 |
+
" all_encoding.extend(encoded.tolist())\n",
|
308 |
+
"\n",
|
309 |
+
" # save files\n",
|
310 |
+
" if (idx + 1) % save_frequency == 0:\n",
|
311 |
+
" print(f\"Saving file {n_file}\")\n",
|
312 |
+
" batch_df = pd.DataFrame.from_dict(\n",
|
313 |
+
" {\"caption\": all_captions, \"encoding\": all_encoding}\n",
|
314 |
+
" )\n",
|
315 |
+
" batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")\n",
|
316 |
+
" all_captions = []\n",
|
317 |
+
" all_encoding = []\n",
|
318 |
+
" n_file += 1\n",
|
319 |
+
"\n",
|
320 |
+
" if len(all_captions):\n",
|
321 |
+
" print(f\"Saving final file {n_file}\")\n",
|
322 |
+
" batch_df = pd.DataFrame.from_dict(\n",
|
323 |
+
" {\"caption\": all_captions, \"encoding\": all_encoding}\n",
|
324 |
+
" )\n",
|
325 |
+
" batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")"
|
326 |
+
]
|
327 |
+
},
|
328 |
+
{
|
329 |
+
"cell_type": "code",
|
330 |
+
"execution_count": null,
|
331 |
+
"id": "7704863d",
|
332 |
+
"metadata": {},
|
333 |
+
"outputs": [],
|
334 |
+
"source": [
|
335 |
+
"encode_dataset(dl, output_dir=encoded_output, save_frequency=save_frequency)"
|
336 |
+
]
|
337 |
+
},
|
338 |
+
{
|
339 |
+
"cell_type": "markdown",
|
340 |
+
"id": "8953dd84",
|
341 |
+
"metadata": {},
|
342 |
+
"source": [
|
343 |
+
"----"
|
344 |
+
]
|
345 |
+
}
|
346 |
+
],
|
347 |
+
"metadata": {
|
348 |
+
"interpreter": {
|
349 |
+
"hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
|
350 |
+
},
|
351 |
+
"kernelspec": {
|
352 |
+
"display_name": "Python 3 (ipykernel)",
|
353 |
+
"language": "python",
|
354 |
+
"name": "python3"
|
355 |
+
},
|
356 |
+
"language_info": {
|
357 |
+
"codemirror_mode": {
|
358 |
+
"name": "ipython",
|
359 |
+
"version": 3
|
360 |
+
},
|
361 |
+
"file_extension": ".py",
|
362 |
+
"mimetype": "text/x-python",
|
363 |
+
"name": "python",
|
364 |
+
"nbconvert_exporter": "python",
|
365 |
+
"pygments_lexer": "ipython3",
|
366 |
+
"version": "3.9.7"
|
367 |
+
}
|
368 |
+
},
|
369 |
+
"nbformat": 4,
|
370 |
+
"nbformat_minor": 5
|
371 |
+
}
|
tools/inference/inference_pipeline.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tools/inference/log_inference_samples.ipynb
ADDED
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import tempfile\n",
|
11 |
+
"from functools import partial\n",
|
12 |
+
"import random\n",
|
13 |
+
"import numpy as np\n",
|
14 |
+
"from PIL import Image\n",
|
15 |
+
"from tqdm.notebook import tqdm\n",
|
16 |
+
"import jax\n",
|
17 |
+
"import jax.numpy as jnp\n",
|
18 |
+
"from flax.training.common_utils import shard, shard_prng_key\n",
|
19 |
+
"from flax.jax_utils import replicate\n",
|
20 |
+
"import wandb\n",
|
21 |
+
"from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n",
|
22 |
+
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
23 |
+
"from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel\n",
|
24 |
+
"from dalle_mini.text import TextNormalizer"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"cell_type": "code",
|
29 |
+
"execution_count": null,
|
30 |
+
"id": "92f4557c-fd7f-4edc-81c2-de0b0a10c270",
|
31 |
+
"metadata": {},
|
32 |
+
"outputs": [],
|
33 |
+
"source": [
|
34 |
+
"run_ids = [\"63otg87g\"]\n",
|
35 |
+
"ENTITY, PROJECT = \"dalle-mini\", \"dalle-mini\" # used only for training run\n",
|
36 |
+
"VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
|
37 |
+
" \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
|
38 |
+
" \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\",\n",
|
39 |
+
")\n",
|
40 |
+
"latest_only = True # log only latest or all versions\n",
|
41 |
+
"suffix = \"\" # mainly for duplicate inference runs with a deleted version\n",
|
42 |
+
"add_clip_32 = False"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": null,
|
48 |
+
"id": "71f27b96-7e6c-4472-a2e4-e99a8fb67a72",
|
49 |
+
"metadata": {},
|
50 |
+
"outputs": [],
|
51 |
+
"source": [
|
52 |
+
"# model.generate parameters - Not used yet\n",
|
53 |
+
"gen_top_k = None\n",
|
54 |
+
"gen_top_p = None\n",
|
55 |
+
"temperature = None"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"execution_count": null,
|
61 |
+
"id": "93b2e24b-f0e5-4abe-a3ec-0aa834cc3bf3",
|
62 |
+
"metadata": {},
|
63 |
+
"outputs": [],
|
64 |
+
"source": [
|
65 |
+
"batch_size = 8\n",
|
66 |
+
"num_images = 128\n",
|
67 |
+
"top_k = 8\n",
|
68 |
+
"text_normalizer = TextNormalizer()\n",
|
69 |
+
"padding_item = \"NONE\"\n",
|
70 |
+
"seed = random.randint(0, 2 ** 32 - 1)\n",
|
71 |
+
"key = jax.random.PRNGKey(seed)\n",
|
72 |
+
"api = wandb.Api()"
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "code",
|
77 |
+
"execution_count": null,
|
78 |
+
"id": "c6a878fa-4bf5-4978-abb5-e235841d765b",
|
79 |
+
"metadata": {},
|
80 |
+
"outputs": [],
|
81 |
+
"source": [
|
82 |
+
"vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
|
83 |
+
"vqgan_params = replicate(vqgan.params)\n",
|
84 |
+
"\n",
|
85 |
+
"clip16 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
|
86 |
+
"processor16 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
|
87 |
+
"clip16_params = replicate(clip16.params)\n",
|
88 |
+
"\n",
|
89 |
+
"if add_clip_32:\n",
|
90 |
+
" clip32 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
|
91 |
+
" processor32 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
|
92 |
+
" clip32_params = replicate(clip32.params)"
|
93 |
+
]
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"cell_type": "code",
|
97 |
+
"execution_count": null,
|
98 |
+
"id": "a500dd07-dbc3-477d-80d4-2b73a3b83ef3",
|
99 |
+
"metadata": {},
|
100 |
+
"outputs": [],
|
101 |
+
"source": [
|
102 |
+
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
103 |
+
"def p_decode(indices, params):\n",
|
104 |
+
" return vqgan.decode_code(indices, params=params)\n",
|
105 |
+
"\n",
|
106 |
+
"\n",
|
107 |
+
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
108 |
+
"def p_clip16(inputs, params):\n",
|
109 |
+
" logits = clip16(params=params, **inputs).logits_per_image\n",
|
110 |
+
" return logits\n",
|
111 |
+
"\n",
|
112 |
+
"\n",
|
113 |
+
"if add_clip_32:\n",
|
114 |
+
"\n",
|
115 |
+
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
116 |
+
" def p_clip32(inputs, params):\n",
|
117 |
+
" logits = clip32(params=params, **inputs).logits_per_image\n",
|
118 |
+
" return logits"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
{
|
122 |
+
"cell_type": "code",
|
123 |
+
"execution_count": null,
|
124 |
+
"id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
|
125 |
+
"metadata": {},
|
126 |
+
"outputs": [],
|
127 |
+
"source": [
|
128 |
+
"with open(\"samples.txt\", encoding=\"utf8\") as f:\n",
|
129 |
+
" samples = [l.strip() for l in f.readlines()]\n",
|
130 |
+
" # make list multiple of batch_size by adding elements\n",
|
131 |
+
" samples_to_add = [padding_item] * (-len(samples) % batch_size)\n",
|
132 |
+
" samples.extend(samples_to_add)\n",
|
133 |
+
" # reshape\n",
|
134 |
+
" samples = [samples[i : i + batch_size] for i in range(0, len(samples), batch_size)]"
|
135 |
+
]
|
136 |
+
},
|
137 |
+
{
|
138 |
+
"cell_type": "code",
|
139 |
+
"execution_count": null,
|
140 |
+
"id": "f3e02d9d-4ee1-49e7-a7bc-4d8b139e9614",
|
141 |
+
"metadata": {},
|
142 |
+
"outputs": [],
|
143 |
+
"source": [
|
144 |
+
"def get_artifact_versions(run_id, latest_only=False):\n",
|
145 |
+
" try:\n",
|
146 |
+
" if latest_only:\n",
|
147 |
+
" return [\n",
|
148 |
+
" api.artifact(\n",
|
149 |
+
" type=\"bart_model\", name=f\"{ENTITY}/{PROJECT}/model-{run_id}:latest\"\n",
|
150 |
+
" )\n",
|
151 |
+
" ]\n",
|
152 |
+
" else:\n",
|
153 |
+
" return api.artifact_versions(\n",
|
154 |
+
" type_name=\"bart_model\",\n",
|
155 |
+
" name=f\"{ENTITY}/{PROJECT}/model-{run_id}\",\n",
|
156 |
+
" per_page=10000,\n",
|
157 |
+
" )\n",
|
158 |
+
" except:\n",
|
159 |
+
" return []"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"execution_count": null,
|
165 |
+
"id": "f0d7ed17-7abb-4a31-ab3c-a12b9039a570",
|
166 |
+
"metadata": {},
|
167 |
+
"outputs": [],
|
168 |
+
"source": [
|
169 |
+
"def get_training_config(run_id):\n",
|
170 |
+
" training_run = api.run(f\"{ENTITY}/{PROJECT}/{run_id}\")\n",
|
171 |
+
" config = training_run.config\n",
|
172 |
+
" return config"
|
173 |
+
]
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"cell_type": "code",
|
177 |
+
"execution_count": null,
|
178 |
+
"id": "7e784a43-626d-4e8d-9e47-a23775b2f35f",
|
179 |
+
"metadata": {},
|
180 |
+
"outputs": [],
|
181 |
+
"source": [
|
182 |
+
"# retrieve inference run details\n",
|
183 |
+
"def get_last_inference_version(run_id):\n",
|
184 |
+
" try:\n",
|
185 |
+
" inference_run = api.run(f\"dalle-mini/dalle-mini/{run_id}-clip16{suffix}\")\n",
|
186 |
+
" return inference_run.summary.get(\"version\", None)\n",
|
187 |
+
" except:\n",
|
188 |
+
" return None"
|
189 |
+
]
|
190 |
+
},
|
191 |
+
{
|
192 |
+
"cell_type": "code",
|
193 |
+
"execution_count": null,
|
194 |
+
"id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
|
195 |
+
"metadata": {},
|
196 |
+
"outputs": [],
|
197 |
+
"source": [
|
198 |
+
"# compile functions - needed only once per run\n",
|
199 |
+
"def pmap_model_function(model):\n",
|
200 |
+
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
201 |
+
" def _generate(tokenized_prompt, key, params):\n",
|
202 |
+
" return model.generate(\n",
|
203 |
+
" **tokenized_prompt,\n",
|
204 |
+
" do_sample=True,\n",
|
205 |
+
" num_beams=1,\n",
|
206 |
+
" prng_key=key,\n",
|
207 |
+
" params=params,\n",
|
208 |
+
" top_k=gen_top_k,\n",
|
209 |
+
" top_p=gen_top_p\n",
|
210 |
+
" )\n",
|
211 |
+
"\n",
|
212 |
+
" return _generate"
|
213 |
+
]
|
214 |
+
},
|
215 |
+
{
|
216 |
+
"cell_type": "code",
|
217 |
+
"execution_count": null,
|
218 |
+
"id": "23b2444c-67a9-44d7-abd1-187ed83a9431",
|
219 |
+
"metadata": {},
|
220 |
+
"outputs": [],
|
221 |
+
"source": [
|
222 |
+
"run_id = run_ids[0]\n",
|
223 |
+
"# TODO: loop over runs"
|
224 |
+
]
|
225 |
+
},
|
226 |
+
{
|
227 |
+
"cell_type": "code",
|
228 |
+
"execution_count": null,
|
229 |
+
"id": "bba70f33-af8b-4eb3-9973-7be672301a0b",
|
230 |
+
"metadata": {},
|
231 |
+
"outputs": [],
|
232 |
+
"source": [
|
233 |
+
"artifact_versions = get_artifact_versions(run_id, latest_only)\n",
|
234 |
+
"last_inference_version = get_last_inference_version(run_id)\n",
|
235 |
+
"training_config = get_training_config(run_id)\n",
|
236 |
+
"run = None\n",
|
237 |
+
"p_generate = None\n",
|
238 |
+
"model_files = [\n",
|
239 |
+
" \"config.json\",\n",
|
240 |
+
" \"flax_model.msgpack\",\n",
|
241 |
+
" \"merges.txt\",\n",
|
242 |
+
" \"special_tokens_map.json\",\n",
|
243 |
+
" \"tokenizer.json\",\n",
|
244 |
+
" \"tokenizer_config.json\",\n",
|
245 |
+
" \"vocab.json\",\n",
|
246 |
+
"]\n",
|
247 |
+
"for artifact in artifact_versions:\n",
|
248 |
+
" print(f\"Processing artifact: {artifact.name}\")\n",
|
249 |
+
" version = int(artifact.version[1:])\n",
|
250 |
+
" results16, results32 = [], []\n",
|
251 |
+
" columns = [\"Caption\"] + [f\"Image {i+1}\" for i in range(top_k)]\n",
|
252 |
+
"\n",
|
253 |
+
" if latest_only:\n",
|
254 |
+
" assert last_inference_version is None or version > last_inference_version\n",
|
255 |
+
" else:\n",
|
256 |
+
" if last_inference_version is None:\n",
|
257 |
+
" # we should start from v0\n",
|
258 |
+
" assert version == 0\n",
|
259 |
+
" elif version <= last_inference_version:\n",
|
260 |
+
" print(\n",
|
261 |
+
" f\"v{version} has already been logged (versions logged up to v{last_inference_version}\"\n",
|
262 |
+
" )\n",
|
263 |
+
" else:\n",
|
264 |
+
" # check we are logging the correct version\n",
|
265 |
+
" assert version == last_inference_version + 1\n",
|
266 |
+
"\n",
|
267 |
+
" # start/resume corresponding run\n",
|
268 |
+
" if run is None:\n",
|
269 |
+
" run = wandb.init(\n",
|
270 |
+
" job_type=\"inference\",\n",
|
271 |
+
" entity=\"dalle-mini\",\n",
|
272 |
+
" project=\"dalle-mini\",\n",
|
273 |
+
" config=training_config,\n",
|
274 |
+
" id=f\"{run_id}-clip16{suffix}\",\n",
|
275 |
+
" resume=\"allow\",\n",
|
276 |
+
" )\n",
|
277 |
+
"\n",
|
278 |
+
" # work in temporary directory\n",
|
279 |
+
" with tempfile.TemporaryDirectory() as tmp:\n",
|
280 |
+
"\n",
|
281 |
+
" # download model files\n",
|
282 |
+
" artifact = run.use_artifact(artifact)\n",
|
283 |
+
" for f in model_files:\n",
|
284 |
+
" artifact.get_path(f).download(tmp)\n",
|
285 |
+
"\n",
|
286 |
+
" # load tokenizer and model\n",
|
287 |
+
" tokenizer = BartTokenizer.from_pretrained(tmp)\n",
|
288 |
+
" model = CustomFlaxBartForConditionalGeneration.from_pretrained(tmp)\n",
|
289 |
+
" model_params = replicate(model.params)\n",
|
290 |
+
"\n",
|
291 |
+
" # pmap model function needs to happen only once per model config\n",
|
292 |
+
" if p_generate is None:\n",
|
293 |
+
" p_generate = pmap_model_function(model)\n",
|
294 |
+
"\n",
|
295 |
+
" # process one batch of captions\n",
|
296 |
+
" for batch in tqdm(samples):\n",
|
297 |
+
" processed_prompts = (\n",
|
298 |
+
" [text_normalizer(x) for x in batch]\n",
|
299 |
+
" if model.config.normalize_text\n",
|
300 |
+
" else list(batch)\n",
|
301 |
+
" )\n",
|
302 |
+
"\n",
|
303 |
+
" # repeat the prompts to distribute over each device and tokenize\n",
|
304 |
+
" processed_prompts = processed_prompts * jax.device_count()\n",
|
305 |
+
" tokenized_prompt = tokenizer(\n",
|
306 |
+
" processed_prompts,\n",
|
307 |
+
" return_tensors=\"jax\",\n",
|
308 |
+
" padding=\"max_length\",\n",
|
309 |
+
" truncation=True,\n",
|
310 |
+
" max_length=128,\n",
|
311 |
+
" ).data\n",
|
312 |
+
" tokenized_prompt = shard(tokenized_prompt)\n",
|
313 |
+
"\n",
|
314 |
+
" # generate images\n",
|
315 |
+
" images = []\n",
|
316 |
+
" pbar = tqdm(\n",
|
317 |
+
" range(num_images // jax.device_count()),\n",
|
318 |
+
" desc=\"Generating Images\",\n",
|
319 |
+
" leave=True,\n",
|
320 |
+
" )\n",
|
321 |
+
" for i in pbar:\n",
|
322 |
+
" key, subkey = jax.random.split(key)\n",
|
323 |
+
" encoded_images = p_generate(\n",
|
324 |
+
" tokenized_prompt, shard_prng_key(subkey), model_params\n",
|
325 |
+
" )\n",
|
326 |
+
" encoded_images = encoded_images.sequences[..., 1:]\n",
|
327 |
+
" decoded_images = p_decode(encoded_images, vqgan_params)\n",
|
328 |
+
" decoded_images = decoded_images.clip(0.0, 1.0).reshape(\n",
|
329 |
+
" (-1, 256, 256, 3)\n",
|
330 |
+
" )\n",
|
331 |
+
" for img in decoded_images:\n",
|
332 |
+
" images.append(\n",
|
333 |
+
" Image.fromarray(np.asarray(img * 255, dtype=np.uint8))\n",
|
334 |
+
" )\n",
|
335 |
+
"\n",
|
336 |
+
" def add_clip_results(results, processor, p_clip, clip_params):\n",
|
337 |
+
" clip_inputs = processor(\n",
|
338 |
+
" text=batch,\n",
|
339 |
+
" images=images,\n",
|
340 |
+
" return_tensors=\"np\",\n",
|
341 |
+
" padding=\"max_length\",\n",
|
342 |
+
" max_length=77,\n",
|
343 |
+
" truncation=True,\n",
|
344 |
+
" ).data\n",
|
345 |
+
" # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
|
346 |
+
" images_per_prompt_indices = np.asarray(\n",
|
347 |
+
" range(0, len(images), batch_size)\n",
|
348 |
+
" )\n",
|
349 |
+
" clip_inputs[\"pixel_values\"] = jnp.concatenate(\n",
|
350 |
+
" list(\n",
|
351 |
+
" clip_inputs[\"pixel_values\"][images_per_prompt_indices + i]\n",
|
352 |
+
" for i in range(batch_size)\n",
|
353 |
+
" )\n",
|
354 |
+
" )\n",
|
355 |
+
" clip_inputs = shard(clip_inputs)\n",
|
356 |
+
" logits = p_clip(clip_inputs, clip_params)\n",
|
357 |
+
" logits = logits.reshape(-1, num_images)\n",
|
358 |
+
" top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
|
359 |
+
" logits = jax.device_get(logits)\n",
|
360 |
+
" # add to results table\n",
|
361 |
+
" for i, (idx, scores, sample) in enumerate(\n",
|
362 |
+
" zip(top_scores, logits, batch)\n",
|
363 |
+
" ):\n",
|
364 |
+
" if sample == padding_item:\n",
|
365 |
+
" continue\n",
|
366 |
+
" cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
|
367 |
+
" top_images = [\n",
|
368 |
+
" wandb.Image(cur_images[x], caption=f\"Score: {scores[x]:.2f}\")\n",
|
369 |
+
" for x in idx\n",
|
370 |
+
" ]\n",
|
371 |
+
" results.append([sample] + top_images)\n",
|
372 |
+
"\n",
|
373 |
+
" # get clip scores\n",
|
374 |
+
" pbar.set_description(\"Calculating CLIP 16 scores\")\n",
|
375 |
+
" add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
|
376 |
+
"\n",
|
377 |
+
" # get clip 32 scores\n",
|
378 |
+
" if add_clip_32:\n",
|
379 |
+
" pbar.set_description(\"Calculating CLIP 32 scores\")\n",
|
380 |
+
" add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
|
381 |
+
"\n",
|
382 |
+
" pbar.close()\n",
|
383 |
+
"\n",
|
384 |
+
" # log results\n",
|
385 |
+
" table = wandb.Table(columns=columns, data=results16)\n",
|
386 |
+
" run.log({\"Samples\": table, \"version\": version})\n",
|
387 |
+
" wandb.finish()\n",
|
388 |
+
"\n",
|
389 |
+
" if add_clip_32:\n",
|
390 |
+
" run = wandb.init(\n",
|
391 |
+
" job_type=\"inference\",\n",
|
392 |
+
" entity=\"dalle-mini\",\n",
|
393 |
+
" project=\"dalle-mini\",\n",
|
394 |
+
" config=training_config,\n",
|
395 |
+
" id=f\"{run_id}-clip32{suffix}\",\n",
|
396 |
+
" resume=\"allow\",\n",
|
397 |
+
" )\n",
|
398 |
+
" table = wandb.Table(columns=columns, data=results32)\n",
|
399 |
+
" run.log({\"Samples\": table, \"version\": version})\n",
|
400 |
+
" wandb.finish()\n",
|
401 |
+
" run = None # ensure we don't log on this run"
|
402 |
+
]
|
403 |
+
},
|
404 |
+
{
|
405 |
+
"cell_type": "code",
|
406 |
+
"execution_count": null,
|
407 |
+
"id": "415d3f54-7226-43de-9eea-4283a948dc93",
|
408 |
+
"metadata": {},
|
409 |
+
"outputs": [],
|
410 |
+
"source": []
|
411 |
+
}
|
412 |
+
],
|
413 |
+
"metadata": {
|
414 |
+
"kernelspec": {
|
415 |
+
"display_name": "Python 3 (ipykernel)",
|
416 |
+
"language": "python",
|
417 |
+
"name": "python3"
|
418 |
+
},
|
419 |
+
"language_info": {
|
420 |
+
"codemirror_mode": {
|
421 |
+
"name": "ipython",
|
422 |
+
"version": 3
|
423 |
+
},
|
424 |
+
"file_extension": ".py",
|
425 |
+
"mimetype": "text/x-python",
|
426 |
+
"name": "python",
|
427 |
+
"nbconvert_exporter": "python",
|
428 |
+
"pygments_lexer": "ipython3",
|
429 |
+
"version": "3.9.7"
|
430 |
+
}
|
431 |
+
},
|
432 |
+
"nbformat": 4,
|
433 |
+
"nbformat_minor": 5
|
434 |
+
}
|
tools/inference/samples.txt
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
t-shirt, size M
|
2 |
+
flower dress, size M
|
3 |
+
white snow covered mountain under blue sky during daytime
|
4 |
+
aerial view of the beach during daytime
|
5 |
+
aerial view of the beach at night
|
6 |
+
a beautiful sunset at a beach with a shell on the shore
|
7 |
+
a farmhouse surrounded by beautiful flowers
|
8 |
+
sunset over green mountains
|
9 |
+
a photo of san francisco golden gate bridge
|
10 |
+
painting of an oniric forest glade surrounded by tall trees
|
11 |
+
a graphite sketch of a gothic cathedral
|
12 |
+
a graphite sketch of Elon Musk
|
13 |
+
still life in the style of Kandinsky
|
14 |
+
still life in the style of Picasso
|
15 |
+
a colorful stairway to heaven
|
16 |
+
a background consisting of colors blue, green, and red
|
17 |
+
Mohammed Ali and Mike Tyson in a match
|
18 |
+
Pele and Maradona in a match
|
19 |
+
view of Mars from space
|
20 |
+
a picture of the Eiffel tower on the moon
|
21 |
+
a picture of the Eiffel tower on the moon, Earth is in the background
|
22 |
+
watercolor of the Eiffel tower on the moon
|
23 |
+
the moon is a skull
|
24 |
+
epic sword fight
|
25 |
+
underwater cathedral
|
26 |
+
a photo of a fantasy version of New York City
|
27 |
+
a picture of fantasy kingdoms
|
28 |
+
a volcano erupting next to San Francisco golden gate bridge
|
29 |
+
Paris in a far future, futuristic Paris
|
30 |
+
real painting of an alien from Monet
|
31 |
+
the communist statue of liberty
|
32 |
+
robots taking control over humans
|
33 |
+
illustration of an astronaut in a space suit playing guitar
|
34 |
+
a clown wearing a spacesuit floating in space
|
35 |
+
a dog playing with a ball
|
36 |
+
a cat sits on top of an alligator
|
37 |
+
a very cute cat laying by a big bike
|
38 |
+
a rat holding a red lightsaber in a white background
|
39 |
+
a very cute giraffe making a funny face
|
40 |
+
A unicorn is passing by a rainbow in a field of flowers
|
41 |
+
an elephant made of carrots
|
42 |
+
an elephant on a unicycle during a circus
|
43 |
+
photography of a penguin watching television
|
44 |
+
a penguin is walking on the Moon, Earth is in the background
|
45 |
+
a penguin standing on a tower of books holds onto a rope from a helicopter
|
46 |
+
rat wearing a crown
|
47 |
+
looking into the sky, 10 airplanes are seen overhead
|
48 |
+
shelves filled with books and alchemy potion bottles
|
49 |
+
this is a detailed high-resolution scan of a human brain
|
50 |
+
a restaurant menu
|
51 |
+
a bottle of coca-cola on a table
|
52 |
+
a peanut
|
53 |
+
a cross-section view of a walnut
|
54 |
+
a living room with two white armchairs and a painting of the collosseum. The painting is mounted above a modern fireplace.
|
55 |
+
a long line of alternating green and red blocks
|
56 |
+
a long line of green blocks on a beach at subset
|
57 |
+
a long line of peaches on a beach at sunset
|
58 |
+
a picture of a castle from minecraft
|
59 |
+
a cute pikachu teapot
|
60 |
+
an illustration of pikachu sitting on a bench eating an ice cream
|
61 |
+
mario is jumping over a zebra
|
62 |
+
famous anime hero
|
63 |
+
star wars concept art
|
64 |
+
Cartoon of a carrot with big eyes
|
65 |
+
a cartoon of a superhero bear
|
66 |
+
an illustration of a cute skeleton wearing a blue hoodie
|
67 |
+
illustration of a baby shark swimming around corals
|
68 |
+
an illustration of an avocado in a beanie riding a motorcycle
|
69 |
+
logo of a robot wearing glasses and reading a book
|
70 |
+
illustration of a cactus lifting weigths
|
71 |
+
logo of a cactus lifting weights
|
72 |
+
a photo of a camera from the future
|
73 |
+
a skeleton with the shape of a spider
|
74 |
+
a collection of glasses is sitting on a table
|
75 |
+
a painting of a capybara sitting on a mountain during fall in surrealist style
|
76 |
+
a pentagonal green clock
|
77 |
+
a small red block sitting on a large green block
|
78 |
+
a storefront that has the word 'openai' written on it
|
79 |
+
a tatoo of a black broccoli
|
80 |
+
a variety of clocks is sitting on a table
|
81 |
+
a table has a train model on it with other cars and things
|
82 |
+
a pixel art illustration of an eagle sitting in a field in the afternoon
|
83 |
+
an emoji of a baby fox wearing a blue hat, green gloves, red shirt, and yellow pants
|
84 |
+
an emoji of a baby penguin wearing a blue hat, blue gloves, red shirt, and green pants
|
85 |
+
an extreme close-up view of a capybara sitting in a field
|
86 |
+
an illustration of a baby cucumber with a mustache playing chess
|
87 |
+
an illustration of a baby daikon radish in a tutu walking a dog
|
88 |
+
an illustration of a baby hedgehog in a cape staring at its reflection in a mirror
|
89 |
+
an illustration of a baby panda with headphones holding an umbrella in the rain
|
90 |
+
urinals are lined up in a jungle
|
91 |
+
a muscular banana sitting upright on a bench smoking watching a banana on television, high definition photography
|
92 |
+
a human face
|
93 |
+
a person is holding a phone and a waterbottle, running a marathon
|
94 |
+
a child eating a birthday cake near some balloons
|
95 |
+
Young woman riding her bike through the forest
|
96 |
+
the best soccer team of the world
|
97 |
+
the best football team of the world
|
98 |
+
the best basketball team of the world
|
99 |
+
happy, happiness
|
100 |
+
sad, sadness
|
101 |
+
the representation of infinity
|
102 |
+
the end of the world
|
103 |
+
the last sunrise on earth
|
104 |
+
a portrait of a nightmare creature watching at you
|
105 |
+
an avocado armchair
|
106 |
+
an armchair in the shape of an avocado
|
107 |
+
illustration of an avocado armchair
|
108 |
+
illustration of an armchair in the shape of an avocado
|
109 |
+
logo of an avocado armchair
|
110 |
+
an avocado armchair flying into space
|
111 |
+
a cute avocado armchair singing karaoke on stage in front of a crowd of strawberry shaped lamps
|
112 |
+
an illustration of an avocado in a christmas sweater staring at its reflection in a mirror
|
113 |
+
illustration of an avocado armchair getting married to a pineapple
|
114 |
+
half human half cat
|
115 |
+
half human half dog
|
116 |
+
half human half pen
|
117 |
+
half human half garbage
|
118 |
+
half human half avocado
|
119 |
+
half human half Eiffel tower
|
120 |
+
a propaganda poster for transhumanism
|
121 |
+
a propaganda poster for building a space elevator
|
122 |
+
a beautiful epic fantasy painting of a space elevator
|
123 |
+
a transformer architecture
|
124 |
+
a transformer in real life
|
{seq2seq → tools/train}/sweep.yaml
RENAMED
@@ -1,6 +1,6 @@
|
|
1 |
-
program:
|
2 |
-
entity:
|
3 |
-
project:
|
4 |
method: random
|
5 |
metric:
|
6 |
name: eval/loss
|
@@ -8,36 +8,47 @@ metric:
|
|
8 |
parameters:
|
9 |
learning_rate:
|
10 |
distribution: log_uniform
|
11 |
-
# from exp(min) to exp(max)
|
12 |
-
min: -
|
13 |
-
max: -5
|
14 |
gradient_accumulation_steps:
|
15 |
value: 8
|
16 |
warmup_steps:
|
17 |
-
|
18 |
-
|
19 |
command:
|
20 |
- python3
|
21 |
- ${program}
|
22 |
-
- "--
|
23 |
-
- "/
|
24 |
-
- "--
|
25 |
-
- "/
|
26 |
-
- "--
|
27 |
-
- "
|
28 |
-
- "--
|
29 |
-
- "--
|
30 |
-
- "--
|
31 |
-
-
|
32 |
-
- "--
|
33 |
-
-
|
|
|
|
|
34 |
- "--per_device_train_batch_size"
|
35 |
- 56
|
36 |
- "--per_device_eval_batch_size"
|
37 |
- 56
|
38 |
-
- "--
|
39 |
-
- 80
|
40 |
-
- "--no_decay"
|
41 |
- "--do_train"
|
42 |
- "--do_eval"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
- ${args}
|
|
|
1 |
+
program: train.py
|
2 |
+
entity: dalle-mini
|
3 |
+
project: dalle-mini
|
4 |
method: random
|
5 |
metric:
|
6 |
name: eval/loss
|
|
|
8 |
parameters:
|
9 |
learning_rate:
|
10 |
distribution: log_uniform
|
11 |
+
# from exp(min) to exp(max)
|
12 |
+
min: -6.9
|
13 |
+
max: -3.5
|
14 |
gradient_accumulation_steps:
|
15 |
value: 8
|
16 |
warmup_steps:
|
17 |
+
value: 4000
|
18 |
+
#TODO: outdated command
|
19 |
command:
|
20 |
- python3
|
21 |
- ${program}
|
22 |
+
- "--tokenizer_name"
|
23 |
+
- "boris/dalle-mini-tokenizer"
|
24 |
+
- "--config_name"
|
25 |
+
- "facebook/bart-large-cnn"
|
26 |
+
- "--dataset_repo_or_path"
|
27 |
+
- "boris/gis_vqgan_f16_16384"
|
28 |
+
- "--streaming"
|
29 |
+
- "--use_auth_token"
|
30 |
+
- "--image_vocab_size"
|
31 |
+
- 16384
|
32 |
+
- "--image_length"
|
33 |
+
- 256
|
34 |
+
- "--normalize_text"
|
35 |
+
- True
|
36 |
- "--per_device_train_batch_size"
|
37 |
- 56
|
38 |
- "--per_device_eval_batch_size"
|
39 |
- 56
|
40 |
+
- "--adafactor"
|
|
|
|
|
41 |
- "--do_train"
|
42 |
- "--do_eval"
|
43 |
+
- "--num_train_epochs"
|
44 |
+
- 1
|
45 |
+
- "--logging_steps"
|
46 |
+
- 40
|
47 |
+
- "--eval_steps"
|
48 |
+
- 800
|
49 |
+
- "--output_dir"
|
50 |
+
- "./output"
|
51 |
+
- "--overwrite_output_dir"
|
52 |
+
- "--max_train_samples"
|
53 |
+
- 10000000
|
54 |
- ${args}
|
tools/train/train.py
ADDED
@@ -0,0 +1,857 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2021 The HuggingFace Team All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the library models for seq2seq, text to image.
|
18 |
+
Script adapted from run_summarization_flax.py
|
19 |
+
"""
|
20 |
+
|
21 |
+
import json
|
22 |
+
import logging
|
23 |
+
import os
|
24 |
+
import sys
|
25 |
+
import time
|
26 |
+
from dataclasses import asdict, dataclass, field
|
27 |
+
from pathlib import Path
|
28 |
+
from typing import Callable, Optional
|
29 |
+
|
30 |
+
import datasets
|
31 |
+
import jax
|
32 |
+
import jax.numpy as jnp
|
33 |
+
import optax
|
34 |
+
import transformers
|
35 |
+
import wandb
|
36 |
+
from datasets import Dataset
|
37 |
+
from flax import jax_utils, traverse_util
|
38 |
+
from flax.jax_utils import unreplicate
|
39 |
+
from flax.serialization import from_bytes, to_bytes
|
40 |
+
from flax.training import train_state
|
41 |
+
from flax.training.common_utils import get_metrics, onehot, shard_prng_key
|
42 |
+
from tqdm import tqdm
|
43 |
+
from transformers import AutoTokenizer, HfArgumentParser
|
44 |
+
from transformers.models.bart.modeling_flax_bart import BartConfig
|
45 |
+
|
46 |
+
from dalle_mini.data import Dataset
|
47 |
+
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
48 |
+
|
49 |
+
logger = logging.getLogger(__name__)
|
50 |
+
|
51 |
+
|
52 |
+
@dataclass
|
53 |
+
class ModelArguments:
|
54 |
+
"""
|
55 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
56 |
+
"""
|
57 |
+
|
58 |
+
model_name_or_path: Optional[str] = field(
|
59 |
+
default=None,
|
60 |
+
metadata={
|
61 |
+
"help": "The model checkpoint for weights initialization."
|
62 |
+
"Don't set if you want to train a model from scratch."
|
63 |
+
},
|
64 |
+
)
|
65 |
+
config_name: Optional[str] = field(
|
66 |
+
default=None,
|
67 |
+
metadata={
|
68 |
+
"help": "Pretrained config name or path if not the same as model_name"
|
69 |
+
},
|
70 |
+
)
|
71 |
+
image_vocab_size: Optional[int] = field(
|
72 |
+
default=None,
|
73 |
+
metadata={"help": "Vocab size of image encoder"},
|
74 |
+
)
|
75 |
+
image_length: Optional[int] = field(
|
76 |
+
default=None,
|
77 |
+
metadata={"help": "Number of tokens per image"},
|
78 |
+
)
|
79 |
+
tokenizer_name: Optional[str] = field(
|
80 |
+
default=None,
|
81 |
+
metadata={
|
82 |
+
"help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
|
83 |
+
},
|
84 |
+
)
|
85 |
+
normalize_text: Optional[bool] = field(
|
86 |
+
default=None,
|
87 |
+
metadata={
|
88 |
+
"help": "Whether to normalize text or not. By default, we refer to base model or don't normalize for new models."
|
89 |
+
},
|
90 |
+
)
|
91 |
+
dtype: Optional[str] = field(
|
92 |
+
default="float32",
|
93 |
+
metadata={
|
94 |
+
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
95 |
+
},
|
96 |
+
)
|
97 |
+
|
98 |
+
|
99 |
+
@dataclass
|
100 |
+
class DataTrainingArguments:
|
101 |
+
"""
|
102 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
103 |
+
"""
|
104 |
+
|
105 |
+
text_column: Optional[str] = field(
|
106 |
+
default="caption",
|
107 |
+
metadata={
|
108 |
+
"help": "The name of the column in the datasets containing the full texts (for summarization)."
|
109 |
+
},
|
110 |
+
)
|
111 |
+
encoding_column: Optional[str] = field(
|
112 |
+
default="encoding",
|
113 |
+
metadata={
|
114 |
+
"help": "The name of the column in the datasets containing the image encodings."
|
115 |
+
},
|
116 |
+
)
|
117 |
+
dataset_repo_or_path: str = field(
|
118 |
+
default=None,
|
119 |
+
metadata={"help": "The dataset repository containing encoded files."},
|
120 |
+
)
|
121 |
+
train_file: Optional[str] = field(
|
122 |
+
default=None,
|
123 |
+
metadata={"help": "The input training data file (glob acceptable)."},
|
124 |
+
)
|
125 |
+
validation_file: Optional[str] = field(
|
126 |
+
default=None,
|
127 |
+
metadata={"help": "An optional input evaluation data file (glob acceptable)."},
|
128 |
+
)
|
129 |
+
dataset_type: str = field(
|
130 |
+
default="datasets",
|
131 |
+
metadata={"help": "Either 🤗 'dataset' (default) or 'webdataset'."},
|
132 |
+
)
|
133 |
+
# data loading should not be a bottleneck so we use "streaming" mode by default
|
134 |
+
streaming: bool = field(
|
135 |
+
default=True,
|
136 |
+
metadata={"help": "Whether to stream the dataset."},
|
137 |
+
)
|
138 |
+
use_auth_token: bool = field(
|
139 |
+
default=False,
|
140 |
+
metadata={
|
141 |
+
"help": "Whether to use the authentication token for private datasets."
|
142 |
+
},
|
143 |
+
)
|
144 |
+
max_source_length: Optional[int] = field(
|
145 |
+
default=128,
|
146 |
+
metadata={
|
147 |
+
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
148 |
+
"than this will be truncated, sequences shorter will be padded."
|
149 |
+
},
|
150 |
+
)
|
151 |
+
max_train_samples: Optional[int] = field(
|
152 |
+
default=None,
|
153 |
+
metadata={
|
154 |
+
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
155 |
+
"value if set."
|
156 |
+
},
|
157 |
+
)
|
158 |
+
max_eval_samples: Optional[int] = field(
|
159 |
+
default=None,
|
160 |
+
metadata={
|
161 |
+
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
162 |
+
"value if set."
|
163 |
+
},
|
164 |
+
)
|
165 |
+
preprocessing_num_workers: Optional[int] = field(
|
166 |
+
default=None,
|
167 |
+
metadata={
|
168 |
+
"help": "The number of processes to use for the preprocessing. Not used in streaming mode."
|
169 |
+
},
|
170 |
+
)
|
171 |
+
overwrite_cache: bool = field(
|
172 |
+
default=False,
|
173 |
+
metadata={
|
174 |
+
"help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
|
175 |
+
},
|
176 |
+
)
|
177 |
+
# default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
|
178 |
+
seed_dataset: int = field(
|
179 |
+
default=None,
|
180 |
+
metadata={
|
181 |
+
"help": "Random seed for the dataset that will be set at the beginning of training."
|
182 |
+
},
|
183 |
+
)
|
184 |
+
|
185 |
+
def __post_init__(self):
|
186 |
+
if self.dataset_repo_or_path is None:
|
187 |
+
raise ValueError("Need a dataset repository or path.")
|
188 |
+
|
189 |
+
|
190 |
+
@dataclass
|
191 |
+
class TrainingArguments:
|
192 |
+
"""
|
193 |
+
Arguments pertaining to training parameters.
|
194 |
+
"""
|
195 |
+
|
196 |
+
output_dir: str = field(
|
197 |
+
metadata={
|
198 |
+
"help": "The output directory where the model predictions and checkpoints will be written."
|
199 |
+
},
|
200 |
+
)
|
201 |
+
overwrite_output_dir: bool = field(
|
202 |
+
default=False,
|
203 |
+
metadata={
|
204 |
+
"help": (
|
205 |
+
"Overwrite the content of the output directory. "
|
206 |
+
"Use this to continue training if output_dir points to a checkpoint directory."
|
207 |
+
)
|
208 |
+
},
|
209 |
+
)
|
210 |
+
|
211 |
+
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
212 |
+
do_eval: bool = field(
|
213 |
+
default=False, metadata={"help": "Whether to run eval on the dev set."}
|
214 |
+
)
|
215 |
+
|
216 |
+
per_device_train_batch_size: int = field(
|
217 |
+
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
|
218 |
+
)
|
219 |
+
per_device_eval_batch_size: int = field(
|
220 |
+
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
|
221 |
+
)
|
222 |
+
|
223 |
+
gradient_accumulation_steps: int = field(
|
224 |
+
default=1,
|
225 |
+
metadata={
|
226 |
+
"help": "Number of updates steps to accumulate before performing a backward/update pass."
|
227 |
+
},
|
228 |
+
)
|
229 |
+
|
230 |
+
learning_rate: float = field(
|
231 |
+
default=5e-5, metadata={"help": "The initial learning rate."}
|
232 |
+
)
|
233 |
+
adafactor: bool = field(
|
234 |
+
default=False,
|
235 |
+
metadata={"help": "Whether or not to replace AdamW by Adafactor."},
|
236 |
+
)
|
237 |
+
weight_decay: float = field(
|
238 |
+
default=None, metadata={"help": "Weight decay if we apply some."}
|
239 |
+
)
|
240 |
+
adam_beta1: float = field(
|
241 |
+
default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}
|
242 |
+
)
|
243 |
+
adam_beta2: float = field(
|
244 |
+
default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}
|
245 |
+
)
|
246 |
+
adam_epsilon: float = field(
|
247 |
+
default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
|
248 |
+
)
|
249 |
+
max_grad_norm: float = field(
|
250 |
+
default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
|
251 |
+
)
|
252 |
+
use_decay: bool = field(
|
253 |
+
default=False,
|
254 |
+
metadata={"help": "Whether to use decay in the learning rate scheduler."},
|
255 |
+
)
|
256 |
+
|
257 |
+
num_train_epochs: float = field(
|
258 |
+
default=3.0, metadata={"help": "Total number of training epochs to perform."}
|
259 |
+
)
|
260 |
+
warmup_steps: int = field(
|
261 |
+
default=0, metadata={"help": "Linear warmup over warmup_steps."}
|
262 |
+
)
|
263 |
+
|
264 |
+
logging_steps: int = field(
|
265 |
+
default=40, metadata={"help": "Log every X updates steps."}
|
266 |
+
)
|
267 |
+
eval_steps: int = field(
|
268 |
+
default=400, metadata={"help": "Run an evaluation every X steps."}
|
269 |
+
)
|
270 |
+
save_steps: int = field(
|
271 |
+
default=4000, metadata={"help": "Save checkpoint every X updates steps."}
|
272 |
+
)
|
273 |
+
log_model: bool = field(
|
274 |
+
default=False,
|
275 |
+
metadata={"help": "Log model to wandb at `save_steps` frequency."},
|
276 |
+
)
|
277 |
+
|
278 |
+
seed_model: int = field(
|
279 |
+
default=42,
|
280 |
+
metadata={
|
281 |
+
"help": "Random seed for the model that will be set at the beginning of training."
|
282 |
+
},
|
283 |
+
)
|
284 |
+
|
285 |
+
push_to_hub: bool = field(
|
286 |
+
default=False,
|
287 |
+
metadata={
|
288 |
+
"help": "Whether or not to upload the trained model to the model hub after training."
|
289 |
+
},
|
290 |
+
)
|
291 |
+
|
292 |
+
resume_from_checkpoint: Optional[str] = field(
|
293 |
+
default=None,
|
294 |
+
metadata={"help": "Reference to a wandb artifact for resuming training."},
|
295 |
+
)
|
296 |
+
|
297 |
+
|
298 |
+
class TrainState(train_state.TrainState):
|
299 |
+
dropout_rng: jnp.ndarray = None
|
300 |
+
epoch: int = 0
|
301 |
+
train_time: float = 0.0 # total time the model trained
|
302 |
+
train_samples: int = 0 # number of samples seen
|
303 |
+
|
304 |
+
def replicate(self):
|
305 |
+
return jax_utils.replicate(self).replace(
|
306 |
+
dropout_rng=shard_prng_key(self.dropout_rng)
|
307 |
+
)
|
308 |
+
|
309 |
+
def restore_state(self, artifact_dir):
|
310 |
+
# restore optimizer state
|
311 |
+
with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
|
312 |
+
new_opt_state = from_bytes(self.opt_state, f.read())
|
313 |
+
|
314 |
+
# restore other parameters
|
315 |
+
with (Path(artifact_dir) / "training_state.json").open("r") as f:
|
316 |
+
training_state = json.load(f)
|
317 |
+
|
318 |
+
# replace state
|
319 |
+
return self.replace(
|
320 |
+
opt_state=new_opt_state,
|
321 |
+
step=training_state["step"],
|
322 |
+
train_time=training_state["train_time"],
|
323 |
+
train_samples=training_state["train_samples"],
|
324 |
+
)
|
325 |
+
|
326 |
+
|
327 |
+
def create_learning_rate_fn(
|
328 |
+
num_warmup_steps: int,
|
329 |
+
learning_rate: float,
|
330 |
+
use_decay: bool,
|
331 |
+
num_train_steps: int = None, # used only with `use_decay`, typically train_size // batch_size * num_epochs
|
332 |
+
) -> Callable[[int], jnp.array]:
|
333 |
+
"""Returns a linear warmup, linear_decay learning rate function."""
|
334 |
+
if use_decay:
|
335 |
+
assert (
|
336 |
+
num_train_steps is not None
|
337 |
+
), "Learning rate with decay requires number of training steps"
|
338 |
+
warmup_fn = optax.linear_schedule(
|
339 |
+
init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
|
340 |
+
)
|
341 |
+
if not use_decay:
|
342 |
+
return warmup_fn
|
343 |
+
decay_fn = optax.linear_schedule(
|
344 |
+
init_value=learning_rate,
|
345 |
+
end_value=0,
|
346 |
+
transition_steps=num_train_steps - num_warmup_steps,
|
347 |
+
)
|
348 |
+
schedule_fn = optax.join_schedules(
|
349 |
+
schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
|
350 |
+
)
|
351 |
+
return schedule_fn
|
352 |
+
|
353 |
+
|
354 |
+
def wandb_log(metrics, step=None, prefix=None):
|
355 |
+
if jax.process_index() == 0:
|
356 |
+
log_metrics = {
|
357 |
+
f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
|
358 |
+
}
|
359 |
+
if step is not None:
|
360 |
+
log_metrics["train/step"] = step
|
361 |
+
wandb.log(log_metrics)
|
362 |
+
|
363 |
+
|
364 |
+
def main():
|
365 |
+
# See all possible arguments by passing the --help flag to this script.
|
366 |
+
parser = HfArgumentParser(
|
367 |
+
(ModelArguments, DataTrainingArguments, TrainingArguments)
|
368 |
+
)
|
369 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
370 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
371 |
+
# let's parse it to get our arguments.
|
372 |
+
model_args, data_args, training_args = parser.parse_json_file(
|
373 |
+
json_file=os.path.abspath(sys.argv[1])
|
374 |
+
)
|
375 |
+
else:
|
376 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
377 |
+
|
378 |
+
if (
|
379 |
+
os.path.exists(training_args.output_dir)
|
380 |
+
and os.listdir(training_args.output_dir)
|
381 |
+
and training_args.do_train
|
382 |
+
and not training_args.overwrite_output_dir
|
383 |
+
):
|
384 |
+
raise ValueError(
|
385 |
+
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
386 |
+
"Use --overwrite_output_dir to overcome."
|
387 |
+
)
|
388 |
+
|
389 |
+
# Make one log on every process with the configuration for debugging.
|
390 |
+
logging.basicConfig(
|
391 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
392 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
393 |
+
level=logging.INFO,
|
394 |
+
)
|
395 |
+
# Setup logging, we only want one process per machine to log things on the screen.
|
396 |
+
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
|
397 |
+
if jax.process_index() == 0:
|
398 |
+
datasets.utils.logging.set_verbosity_warning()
|
399 |
+
transformers.utils.logging.set_verbosity_info()
|
400 |
+
else:
|
401 |
+
datasets.utils.logging.set_verbosity_error()
|
402 |
+
transformers.utils.logging.set_verbosity_error()
|
403 |
+
|
404 |
+
# Set the verbosity to info of the Transformers logger (on main process only):
|
405 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
406 |
+
|
407 |
+
# Load dataset
|
408 |
+
dataset = Dataset(
|
409 |
+
**asdict(data_args),
|
410 |
+
do_train=training_args.do_train,
|
411 |
+
do_eval=training_args.do_eval,
|
412 |
+
)
|
413 |
+
|
414 |
+
# Set up wandb run
|
415 |
+
wandb.init(
|
416 |
+
entity="dalle-mini",
|
417 |
+
project="dalle-mini",
|
418 |
+
job_type="Seq2Seq",
|
419 |
+
config=parser.parse_args(),
|
420 |
+
)
|
421 |
+
|
422 |
+
if training_args.resume_from_checkpoint is not None:
|
423 |
+
artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
|
424 |
+
artifact_dir = artifact.download()
|
425 |
+
|
426 |
+
# load model
|
427 |
+
model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
|
428 |
+
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
429 |
+
print(model.params)
|
430 |
+
|
431 |
+
# load tokenizer
|
432 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
433 |
+
artifact_dir,
|
434 |
+
use_fast=True,
|
435 |
+
)
|
436 |
+
|
437 |
+
else:
|
438 |
+
# Set up our new model config
|
439 |
+
# TODO: simplify with custom config class
|
440 |
+
if model_args.config_name:
|
441 |
+
config = BartConfig.from_pretrained(model_args.config_name)
|
442 |
+
else:
|
443 |
+
config = BartConfig.from_pretrained(model_args.model_name_or_path)
|
444 |
+
if model_args.image_vocab_size:
|
445 |
+
config.image_vocab_size = model_args.image_vocab_size
|
446 |
+
assert (
|
447 |
+
getattr(config, "image_vocab_size") is not None
|
448 |
+
), "image_vocab_size must be specified when not present in base model/config"
|
449 |
+
if model_args.image_length:
|
450 |
+
config.image_length = model_args.image_length
|
451 |
+
assert (
|
452 |
+
getattr(config, "image_length") is not None
|
453 |
+
), "image_length must be specified when not present in base model/config"
|
454 |
+
# we append decoder bos to image vocab
|
455 |
+
config.decoder_start_token_id = config.image_vocab_size
|
456 |
+
# ensure we don't generate bos (in addition to decoder start token)
|
457 |
+
config.force_bos_token_to_be_generated = False
|
458 |
+
config.forced_bos_token_id = None # we don't need this token
|
459 |
+
config.forced_eos_token_id = None # we don't need this token
|
460 |
+
|
461 |
+
config.tie_word_embeddings = False
|
462 |
+
config.min_length = config.image_length + 1
|
463 |
+
config.max_length = config.image_length + 1
|
464 |
+
|
465 |
+
# below tokens need to be set to avoid error during generation (converted to jnp.array)
|
466 |
+
# they are not expected to be used and are set to unreachable token id
|
467 |
+
config.bos_token_id = config.image_vocab_size + 1
|
468 |
+
config.pos_token_id = config.image_vocab_size + 1
|
469 |
+
config.eos_token_id = config.image_vocab_size + 1
|
470 |
+
|
471 |
+
# save whether we normalize the text
|
472 |
+
if model_args.normalize_text is not None:
|
473 |
+
config.normalize_text = model_args.normalize_text
|
474 |
+
else:
|
475 |
+
config.normalize_text = getattr(config, "normalize_text", False)
|
476 |
+
|
477 |
+
# Load or create new model
|
478 |
+
if model_args.model_name_or_path:
|
479 |
+
model = CustomFlaxBartForConditionalGeneration.from_pretrained(
|
480 |
+
model_args.model_name_or_path,
|
481 |
+
config=config,
|
482 |
+
seed=training_args.seed_model,
|
483 |
+
dtype=getattr(jnp, model_args.dtype),
|
484 |
+
)
|
485 |
+
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
486 |
+
print(model.params)
|
487 |
+
else:
|
488 |
+
model = CustomFlaxBartForConditionalGeneration(
|
489 |
+
config,
|
490 |
+
seed=training_args.seed_model,
|
491 |
+
dtype=getattr(jnp, model_args.dtype),
|
492 |
+
)
|
493 |
+
|
494 |
+
# Load tokenizer
|
495 |
+
if model_args.tokenizer_name is not None:
|
496 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
497 |
+
model_args.tokenizer_name, use_fast=True
|
498 |
+
)
|
499 |
+
else:
|
500 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
501 |
+
model_args.model_name_or_path,
|
502 |
+
use_fast=True,
|
503 |
+
)
|
504 |
+
|
505 |
+
logger.info(f"TPUs: {jax.device_count()}")
|
506 |
+
assert jax.device_count() == 8, "TPUs in use, please check running processes"
|
507 |
+
|
508 |
+
# Preprocessing the datasets.
|
509 |
+
# We need to normalize and tokenize inputs and targets.
|
510 |
+
|
511 |
+
dataset.preprocess(
|
512 |
+
tokenizer=tokenizer,
|
513 |
+
decoder_start_token_id=model.config.decoder_start_token_id,
|
514 |
+
normalize_text=model.config.normalize_text,
|
515 |
+
)
|
516 |
+
|
517 |
+
# Initialize our training
|
518 |
+
rng = jax.random.PRNGKey(training_args.seed_model)
|
519 |
+
rng, dropout_rng = jax.random.split(rng)
|
520 |
+
|
521 |
+
# Store some constant
|
522 |
+
num_epochs = int(training_args.num_train_epochs)
|
523 |
+
train_batch_size = (
|
524 |
+
int(training_args.per_device_train_batch_size) * jax.device_count()
|
525 |
+
)
|
526 |
+
batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
|
527 |
+
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
528 |
+
len_train_dataset, len_eval_dataset = dataset.length
|
529 |
+
steps_per_epoch = (
|
530 |
+
len_train_dataset // train_batch_size if len_train_dataset is not None else None
|
531 |
+
)
|
532 |
+
num_train_steps = (
|
533 |
+
steps_per_epoch * num_epochs if steps_per_epoch is not None else None
|
534 |
+
)
|
535 |
+
|
536 |
+
# Create learning rate schedule
|
537 |
+
learning_rate_fn = create_learning_rate_fn(
|
538 |
+
training_args.warmup_steps,
|
539 |
+
training_args.learning_rate,
|
540 |
+
training_args.use_decay,
|
541 |
+
num_train_steps,
|
542 |
+
)
|
543 |
+
|
544 |
+
# We use Optax's "masking" functionality to not apply weight decay
|
545 |
+
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
546 |
+
# mask boolean with the same structure as the parameters.
|
547 |
+
# The mask is True for parameters that should be decayed.
|
548 |
+
# Note that this mask is specifically adapted for FlaxBart.
|
549 |
+
def decay_mask_fn(params):
|
550 |
+
flat_params = traverse_util.flatten_dict(params)
|
551 |
+
layer_norm_params = [
|
552 |
+
(name, "scale")
|
553 |
+
for name in [
|
554 |
+
"self_attn_layer_norm",
|
555 |
+
"layernorm_embedding",
|
556 |
+
"final_layer_norm",
|
557 |
+
]
|
558 |
+
]
|
559 |
+
flat_mask = {
|
560 |
+
path: (path[-1] != "bias" and path[-2:] not in layer_norm_params)
|
561 |
+
for path in flat_params
|
562 |
+
}
|
563 |
+
return traverse_util.unflatten_dict(flat_mask)
|
564 |
+
|
565 |
+
# create adam optimizer
|
566 |
+
if training_args.adafactor:
|
567 |
+
# We use the default parameters here to initialize adafactor,
|
568 |
+
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
|
569 |
+
optimizer = optax.adafactor(
|
570 |
+
learning_rate=learning_rate_fn,
|
571 |
+
weight_decay_rate=training_args.weight_decay,
|
572 |
+
weight_decay_mask=decay_mask_fn,
|
573 |
+
clipping_threshold=training_args.max_grad_norm,
|
574 |
+
)
|
575 |
+
else:
|
576 |
+
optimizer = optax.adamw(
|
577 |
+
learning_rate=learning_rate_fn,
|
578 |
+
b1=training_args.adam_beta1,
|
579 |
+
b2=training_args.adam_beta2,
|
580 |
+
eps=training_args.adam_epsilon,
|
581 |
+
weight_decay=training_args.weight_decay,
|
582 |
+
mask=decay_mask_fn,
|
583 |
+
)
|
584 |
+
|
585 |
+
# add gradient accumulation
|
586 |
+
if training_args.gradient_accumulation_steps > 1:
|
587 |
+
optimizer = optax.chain(
|
588 |
+
optax.apply_every(training_args.gradient_accumulation_steps), optimizer
|
589 |
+
)
|
590 |
+
|
591 |
+
# Setup train state
|
592 |
+
state = TrainState.create(
|
593 |
+
apply_fn=model.__call__,
|
594 |
+
params=model.params,
|
595 |
+
tx=optimizer,
|
596 |
+
dropout_rng=dropout_rng,
|
597 |
+
)
|
598 |
+
if training_args.resume_from_checkpoint is not None:
|
599 |
+
# restore optimizer state and other parameters
|
600 |
+
# we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
|
601 |
+
state = state.restore_state(artifact_dir)
|
602 |
+
|
603 |
+
# label smoothed cross entropy
|
604 |
+
def loss_fn(logits, labels):
|
605 |
+
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
|
606 |
+
loss = loss.mean()
|
607 |
+
return loss
|
608 |
+
|
609 |
+
# Define gradient update step fn
|
610 |
+
def train_step(state, batch, delta_time):
|
611 |
+
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
612 |
+
|
613 |
+
def compute_loss(params, batch):
|
614 |
+
labels = batch.pop("labels")
|
615 |
+
logits = state.apply_fn(
|
616 |
+
**batch, params=params, dropout_rng=dropout_rng, train=True
|
617 |
+
)[0]
|
618 |
+
loss = loss_fn(logits, labels)
|
619 |
+
return loss
|
620 |
+
|
621 |
+
grad_fn = jax.value_and_grad(compute_loss)
|
622 |
+
loss, grads = grad_fn(state.params, batch)
|
623 |
+
grads = jax.lax.pmean(grads, "batch")
|
624 |
+
state = state.apply_gradients(
|
625 |
+
grads=grads,
|
626 |
+
dropout_rng=new_dropout_rng,
|
627 |
+
train_time=state.train_time + delta_time,
|
628 |
+
train_samples=state.train_samples + train_batch_size,
|
629 |
+
)
|
630 |
+
|
631 |
+
metrics = {
|
632 |
+
"loss": loss,
|
633 |
+
"learning_rate": learning_rate_fn(state.step),
|
634 |
+
}
|
635 |
+
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
636 |
+
|
637 |
+
return state, metrics
|
638 |
+
|
639 |
+
# Define eval fn
|
640 |
+
def eval_step(params, batch):
|
641 |
+
labels = batch.pop("labels")
|
642 |
+
logits = model(**batch, params=params, train=False)[0]
|
643 |
+
loss = loss_fn(logits, labels)
|
644 |
+
|
645 |
+
# summarize metrics
|
646 |
+
metrics = {"loss": loss}
|
647 |
+
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
648 |
+
return metrics
|
649 |
+
|
650 |
+
# Create parallel version of the train and eval step
|
651 |
+
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
652 |
+
p_eval_step = jax.pmap(eval_step, "batch")
|
653 |
+
|
654 |
+
logger.info("***** Running training *****")
|
655 |
+
logger.info(f" Num examples = {len_train_dataset}")
|
656 |
+
logger.info(f" Num Epochs = {num_epochs}")
|
657 |
+
logger.info(
|
658 |
+
f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
|
659 |
+
)
|
660 |
+
logger.info(
|
661 |
+
f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
|
662 |
+
)
|
663 |
+
epochs = tqdm(
|
664 |
+
range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
|
665 |
+
)
|
666 |
+
|
667 |
+
# set default x-axis as 'train/step'
|
668 |
+
wandb_log({}, step=state.step)
|
669 |
+
wandb.define_metric("*", step_metric="train/step")
|
670 |
+
|
671 |
+
# add interesting config parameters
|
672 |
+
wandb.config.update(
|
673 |
+
{
|
674 |
+
"len_train_dataset": len_train_dataset,
|
675 |
+
"len_eval_dataset": len_eval_dataset,
|
676 |
+
"batch_size_per_update": batch_size_per_update,
|
677 |
+
}
|
678 |
+
)
|
679 |
+
|
680 |
+
# replicate state on each device
|
681 |
+
state = state.replicate()
|
682 |
+
|
683 |
+
def run_evaluation():
|
684 |
+
# ======================== Evaluating ==============================
|
685 |
+
eval_metrics = []
|
686 |
+
if training_args.do_eval:
|
687 |
+
eval_loader = dataset.dataloader("eval", eval_batch_size)
|
688 |
+
eval_steps = (
|
689 |
+
len_eval_dataset // eval_batch_size
|
690 |
+
if len_eval_dataset is not None
|
691 |
+
else None
|
692 |
+
)
|
693 |
+
for batch in tqdm(
|
694 |
+
eval_loader,
|
695 |
+
desc="Evaluating...",
|
696 |
+
position=2,
|
697 |
+
leave=False,
|
698 |
+
total=eval_steps,
|
699 |
+
):
|
700 |
+
# Model forward
|
701 |
+
metrics = p_eval_step(state.params, batch)
|
702 |
+
eval_metrics.append(metrics)
|
703 |
+
|
704 |
+
# normalize eval metrics
|
705 |
+
eval_metrics = get_metrics(eval_metrics)
|
706 |
+
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
707 |
+
|
708 |
+
# log metrics
|
709 |
+
wandb_log(eval_metrics, step=unreplicate(state.step), prefix="eval")
|
710 |
+
|
711 |
+
# Print metrics and update progress bar
|
712 |
+
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
|
713 |
+
epochs.write(desc)
|
714 |
+
epochs.desc = desc
|
715 |
+
|
716 |
+
return eval_metrics
|
717 |
+
|
718 |
+
def run_save_model(state, eval_metrics=None):
|
719 |
+
if jax.process_index() == 0:
|
720 |
+
params = jax.device_get(unreplicate(state.params))
|
721 |
+
# save model locally
|
722 |
+
model.save_pretrained(
|
723 |
+
training_args.output_dir,
|
724 |
+
params=params,
|
725 |
+
)
|
726 |
+
|
727 |
+
# save tokenizer
|
728 |
+
tokenizer.save_pretrained(training_args.output_dir)
|
729 |
+
|
730 |
+
# save state
|
731 |
+
opt_state = unreplicate(state.opt_state)
|
732 |
+
with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
|
733 |
+
f.write(to_bytes(opt_state))
|
734 |
+
state_dict = {
|
735 |
+
k: jax.device_get(unreplicate(getattr(state, k))).item()
|
736 |
+
for k in ["step", "epoch", "train_time", "train_samples"]
|
737 |
+
}
|
738 |
+
with (Path(training_args.output_dir) / "training_state.json").open(
|
739 |
+
"w"
|
740 |
+
) as f:
|
741 |
+
json.dump(
|
742 |
+
state_dict,
|
743 |
+
f,
|
744 |
+
)
|
745 |
+
|
746 |
+
# save to W&B
|
747 |
+
if training_args.log_model:
|
748 |
+
# save some space
|
749 |
+
c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
|
750 |
+
c.cleanup(wandb.util.from_human_size("10GB"))
|
751 |
+
|
752 |
+
metadata = dict(state_dict)
|
753 |
+
if eval_metrics is not None:
|
754 |
+
metadata["eval"] = eval_metrics
|
755 |
+
artifact = wandb.Artifact(
|
756 |
+
name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
|
757 |
+
)
|
758 |
+
artifact.add_file(
|
759 |
+
str(Path(training_args.output_dir) / "flax_model.msgpack")
|
760 |
+
)
|
761 |
+
artifact.add_file(str(Path(training_args.output_dir) / "config.json"))
|
762 |
+
artifact.add_file(
|
763 |
+
str(Path(training_args.output_dir) / "tokenizer.json")
|
764 |
+
)
|
765 |
+
artifact.add_file(
|
766 |
+
str(Path(training_args.output_dir) / "tokenizer_config.json")
|
767 |
+
)
|
768 |
+
artifact.add_file(str(Path(training_args.output_dir) / "vocab.json"))
|
769 |
+
artifact.add_file(str(Path(training_args.output_dir) / "merges.txt"))
|
770 |
+
artifact.add_file(
|
771 |
+
str(Path(training_args.output_dir) / "special_tokens_map.json")
|
772 |
+
)
|
773 |
+
artifact.add_file(
|
774 |
+
str(Path(training_args.output_dir) / "opt_state.msgpack")
|
775 |
+
)
|
776 |
+
artifact.add_file(
|
777 |
+
str(Path(training_args.output_dir) / "training_state.json")
|
778 |
+
)
|
779 |
+
|
780 |
+
wandb.run.log_artifact(artifact)
|
781 |
+
|
782 |
+
# save to the hub
|
783 |
+
if training_args.push_to_hub:
|
784 |
+
model.save_pretrained(
|
785 |
+
training_args.output_dir,
|
786 |
+
params=params,
|
787 |
+
push_to_hub=training_args.push_to_hub,
|
788 |
+
commit_message=f"Saving weights and logs at step {unreplicate(state.step)+1}",
|
789 |
+
temp_dir=True, # avoid issues with being in a repository
|
790 |
+
)
|
791 |
+
|
792 |
+
# init variables
|
793 |
+
last_time = time.perf_counter()
|
794 |
+
train_metrics = None
|
795 |
+
|
796 |
+
for epoch in epochs:
|
797 |
+
state.replace(epoch=jax_utils.replicate(epoch))
|
798 |
+
# ======================== Training ================================
|
799 |
+
wandb_log({"train/epoch": epoch}, step=unreplicate(state.step))
|
800 |
+
|
801 |
+
# Generate an epoch by shuffling sampling indices from the train dataset
|
802 |
+
train_loader = dataset.dataloader("train", train_batch_size)
|
803 |
+
# train
|
804 |
+
for batch in tqdm(
|
805 |
+
train_loader,
|
806 |
+
desc="Training...",
|
807 |
+
position=1,
|
808 |
+
leave=False,
|
809 |
+
total=steps_per_epoch,
|
810 |
+
):
|
811 |
+
|
812 |
+
# calculate delta time (we have a lag of one step but it's ok)
|
813 |
+
new_time = time.perf_counter()
|
814 |
+
delta_time = new_time - last_time
|
815 |
+
last_time = new_time
|
816 |
+
|
817 |
+
# train step
|
818 |
+
state, train_metrics = p_train_step(
|
819 |
+
state, batch, jax_utils.replicate(delta_time)
|
820 |
+
)
|
821 |
+
step = unreplicate(state.step)
|
822 |
+
|
823 |
+
if step % training_args.logging_steps == 0 and jax.process_index() == 0:
|
824 |
+
# log metrics
|
825 |
+
metrics = unreplicate(train_metrics)
|
826 |
+
# log state parameters
|
827 |
+
state_dict = {
|
828 |
+
k.split("_")[-1]: unreplicate(getattr(state, k))
|
829 |
+
for k in ["epoch", "train_time", "train_samples"]
|
830 |
+
}
|
831 |
+
wandb_log({**metrics, **state_dict}, step=step, prefix="train")
|
832 |
+
|
833 |
+
eval_metrics = None
|
834 |
+
if training_args.eval_steps and step % training_args.eval_steps == 0:
|
835 |
+
eval_metrics = run_evaluation()
|
836 |
+
|
837 |
+
if step % training_args.save_steps == 0:
|
838 |
+
run_save_model(state, eval_metrics)
|
839 |
+
|
840 |
+
# log final train metrics
|
841 |
+
if train_metrics is not None:
|
842 |
+
train_metrics = unreplicate(train_metrics)
|
843 |
+
wandb_log(train_metrics, step=step, prefix="train")
|
844 |
+
|
845 |
+
epochs.write(
|
846 |
+
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
|
847 |
+
)
|
848 |
+
|
849 |
+
# Final evaluation
|
850 |
+
eval_metrics = run_evaluation()
|
851 |
+
|
852 |
+
# save checkpoint after each epoch
|
853 |
+
run_save_model(state, eval_metrics)
|
854 |
+
|
855 |
+
|
856 |
+
if __name__ == "__main__":
|
857 |
+
main()
|