teticio committed on
Commit
6dff871
1 Parent(s): f4441f8

work with grayscale images

Browse files
notebooks/test-model.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
requirements-lock.txt ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.2.0
2
+ accelerate==0.12.0
3
+ aiobotocore==2.3.4
4
+ aiohttp==3.8.1
5
+ aioitertools==0.10.0
6
+ aiosignal==1.2.0
7
+ analytics-python==1.4.0
8
+ anyio==3.6.1
9
+ appdirs==1.4.4
10
+ argon2-cffi==21.3.0
11
+ argon2-cffi-bindings==21.2.0
12
+ async-timeout==4.0.2
13
+ attrs==21.4.0
14
+ audioread==2.1.9
15
+ backcall==0.2.0
16
+ backoff==1.10.0
17
+ bcrypt==3.2.2
18
+ beautifulsoup4==4.11.1
19
+ bertviz==1.4.0
20
+ black==22.6.0
21
+ bleach==5.0.1
22
+ boto3==1.21.21
23
+ botocore==1.24.21
24
+ cachetools==5.2.0
25
+ captum==0.5.0
26
+ certifi==2022.6.15
27
+ cffi==1.15.1
28
+ charset-normalizer==2.1.0
29
+ click==8.1.3
30
+ cloudpickle==2.1.0
31
+ cryptography==37.0.4
32
+ cycler==0.11.0
33
+ datasets==2.4.0
34
+ debugpy==1.6.2
35
+ decorator==5.1.1
36
+ deepspeed==0.7.0
37
+ defusedxml==0.7.1
38
+ diffusers==0.1.3
39
+ dill==0.3.5.1
40
+ entrypoints==0.4
41
+ fastapi==0.79.0
42
+ fastjsonschema==2.16.1
43
+ ffmpy==0.3.0
44
+ filelock==3.8.0
45
+ fonttools==4.34.4
46
+ frozenlist==1.3.1
47
+ fsspec==2022.7.1
48
+ google-auth==2.10.0
49
+ google-auth-oauthlib==0.4.6
50
+ google-pasta==0.2.0
51
+ gradio==3.1.4
52
+ grpcio==1.47.0
53
+ h11==0.12.0
54
+ hjson==3.0.2
55
+ httpcore==0.15.0
56
+ httpx==0.23.0
57
+ huggingface-hub==0.8.1
58
+ idna==3.3
59
+ importlib-metadata==4.12.0
60
+ ipykernel==6.15.1
61
+ ipython==7.34.0
62
+ ipython-genutils==0.2.0
63
+ ipywidgets==7.7.1
64
+ jedi==0.18.1
65
+ Jinja2==3.1.2
66
+ jmespath==1.0.1
67
+ joblib==1.1.0
68
+ jsonschema==4.9.1
69
+ jupyter-client==7.3.4
70
+ jupyter-core==4.11.1
71
+ jupyterlab-pygments==0.2.2
72
+ jupyterlab-widgets==1.1.1
73
+ kiwisolver==1.4.4
74
+ librosa==0.9.2
75
+ linkify-it-py==1.0.3
76
+ llvmlite==0.39.0
77
+ lxml==4.9.1
78
+ Markdown==3.4.1
79
+ markdown-it-py==2.1.0
80
+ MarkupSafe==2.1.1
81
+ matplotlib==3.5.2
82
+ matplotlib-inline==0.1.3
83
+ mdit-py-plugins==0.3.0
84
+ mdurl==0.1.1
85
+ mistune==0.8.4
86
+ monotonic==1.6
87
+ more-itertools==8.14.0
88
+ multidict==6.0.2
89
+ multiprocess==0.70.13
90
+ munkres==1.1.4
91
+ mypy-extensions==0.4.3
92
+ nbclient==0.6.6
93
+ nbconvert==6.5.1
94
+ nbformat==5.4.0
95
+ nest-asyncio==1.5.5
96
+ networkx==2.8.5
97
+ ninja==1.10.2.3
98
+ nlp==0.4.0
99
+ nltk==3.7
100
+ notebook==6.4.12
101
+ numba==0.56.0
102
+ numpy==1.22.4
103
+ oauthlib==3.2.0
104
+ orjson==3.7.11
105
+ packaging==21.3
106
+ pandas==1.4.3
107
+ pandocfilters==1.5.0
108
+ paramiko==2.11.0
109
+ parso==0.8.3
110
+ pathos==0.2.9
111
+ pathspec==0.9.0
112
+ pexpect==4.8.0
113
+ pickleshare==0.7.5
114
+ Pillow==9.2.0
115
+ platformdirs==2.5.2
116
+ pluggy==0.13.1
117
+ pooch==1.6.0
118
+ pox==0.3.1
119
+ ppft==1.7.6.5
120
+ prometheus-client==0.14.1
121
+ prompt-toolkit==3.0.30
122
+ protobuf==3.19.4
123
+ protobuf3-to-dict==0.1.5
124
+ psutil==5.9.1
125
+ ptyprocess==0.7.0
126
+ py==1.11.0
127
+ py-cpuinfo==8.0.0
128
+ pyarrow==9.0.0
129
+ pyasn1==0.4.8
130
+ pyasn1-modules==0.2.8
131
+ pycparser==2.21
132
+ pycryptodome==3.15.0
133
+ pydantic==1.9.1
134
+ pydub==0.25.1
135
+ Pygments==2.12.0
136
+ PyNaCl==1.5.0
137
+ pyparsing==3.0.9
138
+ pyrsistent==0.18.1
139
+ pytest==5.4.3
140
+ python-dateutil==2.8.2
141
+ python-dotenv==0.20.0
142
+ python-multipart==0.0.5
143
+ pytz==2022.1
144
+ PyYAML==6.0
145
+ pyzmq==23.2.0
146
+ regex==2022.7.25
147
+ requests==2.28.1
148
+ requests-oauthlib==1.3.1
149
+ resampy==0.4.0
150
+ responses==0.18.0
151
+ rfc3986==1.5.0
152
+ rsa==4.9
153
+ s3fs==2022.7.1
154
+ s3transfer==0.5.2
155
+ sagemaker==2.103.0
156
+ scikit-learn==1.1.2
157
+ scipy==1.9.0
158
+ seaborn==0.11.2
159
+ Send2Trash==1.8.0
160
+ sentencepiece==0.1.97
161
+ shap==0.41.0
162
+ six==1.16.0
163
+ slicer==0.0.7
164
+ smdebug-rulesconfig==1.0.1
165
+ sniffio==1.2.0
166
+ snorkel==0.9.9
167
+ SoundFile==0.10.3.post1
168
+ soupsieve==2.3.2.post1
169
+ starlette==0.19.1
170
+ tensorboard==2.9.1
171
+ tensorboard-data-server==0.6.1
172
+ tensorboard-plugin-wit==1.8.1
173
+ terminado==0.15.0
174
+ threadpoolctl==3.1.0
175
+ tinycss2==1.1.1
176
+ tokenizers==0.12.1
177
+ toml==0.10.2
178
+ tomli==2.0.1
179
+ torch==1.12.1
180
+ torchvision==0.13.1
181
+ tornado==6.2
182
+ tqdm==4.64.0
183
+ traitlets==5.3.0
184
+ transformers==4.21.1
185
+ transformers-interpret==0.7.5
186
+ typing_extensions==4.3.0
187
+ uc-micro-py==1.0.1
188
+ urllib3==1.26.11
189
+ uvicorn==0.18.2
190
+ wcwidth==0.2.5
191
+ webencodings==0.5.1
192
+ Werkzeug==2.2.2
193
+ widgetsnbextension==3.6.1
194
+ wrapt==1.14.1
195
+ xxhash==3.0.0
196
+ yapf==0.32.0
197
+ yarl==1.8.1
198
+ zipp==3.8.1
src/audio_to_images.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import os
2
  import re
3
  import io
 
1
+ # TODO
2
+ # run on sagemaker
3
+ # run with deepspeed
4
+
5
+
6
  import os
7
  import re
8
  import io
src/train_unconditional.py CHANGED
@@ -39,8 +39,8 @@ def main(args):
39
 
40
  model = UNet2DModel(
41
  sample_size=args.resolution,
42
- in_channels=3,
43
- out_channels=3,
44
  layers_per_block=2,
45
  block_out_channels=(128, 128, 256, 256, 512, 512),
46
  down_block_types=(
@@ -101,7 +101,7 @@ def main(args):
101
  )
102
 
103
  def transforms(examples):
104
- images = [augmentations(image.convert("RGB")) for image in examples["image"]]
105
  return {"input": images}
106
 
107
  dataset.set_transform(transforms)
@@ -215,8 +215,7 @@ def main(args):
215
  "test_samples", images_processed, epoch
216
  )
217
  for _, image in enumerate(images_processed):
218
- image = Image.fromarray(np.mean(image, axis=0).astype("uint8"))
219
- audio = mel.image_to_audio(image)
220
  accelerator.trackers[0].writer.add_audio(
221
  f"test_audio_{_}",
222
  audio,
 
39
 
40
  model = UNet2DModel(
41
  sample_size=args.resolution,
42
+ in_channels=1,
43
+ out_channels=1,
44
  layers_per_block=2,
45
  block_out_channels=(128, 128, 256, 256, 512, 512),
46
  down_block_types=(
 
101
  )
102
 
103
  def transforms(examples):
104
+ images = [augmentations(image) for image in examples["image"]]
105
  return {"input": images}
106
 
107
  dataset.set_transform(transforms)
 
215
  "test_samples", images_processed, epoch
216
  )
217
  for _, image in enumerate(images_processed):
218
+ audio = mel.image_to_audio(Image.fromarray(image[0]))
 
219
  accelerator.trackers[0].writer.add_audio(
220
  f"test_audio_{_}",
221
  audio,